[pipeline] add pipeline support for all T5 models (#4310)

* complete policy for T5Model & T5ForConditionalGeneration

* modify function signature in forwards

* add forward for T5Model

* add forward for T5ForConditionalGeneration

* fix a bug

* fix hidden_states transport in the decoder

* fix the passing of encoder_outputs
Authored by Baizhou Zhang on 2023-07-25 14:45:33 +08:00, committed by Hongxin Liu
parent d0807122e2 · commit 083d7da33d
3 changed files with 388 additions and 19 deletions
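
Note: the commit bullets above touch the trickiest part of pipelining an encoder-decoder model: the output of the last encoder stage has to travel along with the micro-batch as encoder_hidden_states for every decoder stage, while each stage also hands its own intermediate hidden_states to the next stage. The real logic lives in T5PipelineForwards in the modeling file, which is not shown in this diff; the snippet below is only a rough sketch of that handoff, and every name in it (stage_forward, blocks, start, end) is made up for illustration.

```python
from typing import Dict, Optional

import torch
import torch.nn as nn


def stage_forward(blocks: nn.ModuleList,
                  start: int,
                  end: int,
                  hidden_states: torch.Tensor,
                  encoder_hidden_states: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
    """Run only the T5 blocks owned by this pipeline stage (illustrative sketch).

    hidden_states is the embedding output on the first stage, or the tensor
    handed over from the previous stage otherwise; decoder stages additionally
    receive the final encoder output so cross-attention keeps working.
    """
    for block in blocks[start:end]:
        if encoder_hidden_states is None:
            # encoder block: self-attention + FFN only
            hidden_states = block(hidden_states)
        else:
            # decoder block: also attends to the encoder output
            hidden_states = block(hidden_states, encoder_hidden_states)
    # Intermediate stages return a dict that the pipeline schedule forwards to
    # the next stage; the real forwards return proper HF-style model outputs.
    return {"hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states}
```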

@@ -293,21 +293,42 @@ class T5BasePolicy(Policy):
 class T5ModelPolicy(T5BasePolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
         from transformers import T5Model
 
-        base_policy = super().module_policy()
+        policy = super().module_policy()
 
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=VocabParallelEmbedding1D,
             ),
-                policy=base_policy,
+                policy=policy,
                 target_key=T5Model)
 
-        return base_policy
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
+
+        return policy
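
Note: set_pipeline_forward is defined on T5BasePolicy earlier in this file, outside the displayed hunk. Conceptually it pre-binds the stage manager and the block indices owned by this stage into the new forward and registers the result as a method replacement for the model class. The sketch below shows only that idea; the real helper's signature and the registration mechanism are ColossalAI internals and may differ.

```python
from functools import partial


def set_pipeline_forward_sketch(policy: dict, model_cls, new_forward, stage_manager, stage_index):
    """Illustrative sketch only: make model_cls.forward stage-aware by
    pre-binding the pipeline context into the replacement function."""
    bound_forward = partial(new_forward,
                            stage_manager=stage_manager,  # knows which stage this rank runs
                            stage_index=stage_index)      # which encoder/decoder blocks it owns
    # The real policy records this as a method-replacement description that the
    # sharder applies later; a plain dict entry stands in for that here.
    policy.setdefault(model_cls, {})["forward"] = bound_forward
```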
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
+        return []
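
Note: get_shared_params reports tied weights as dicts that map a pipeline stage index to the copy of the parameter held on that stage: stage 0 owns module.shared, the first decoder stage owns decoder.embed_tokens, and their gradients have to stay in sync. The ColossalAI runtime consumes this structure; a hypothetical consumer might look like the following.

```python
import torch.distributed as dist


def sync_tied_grads(shared_params, my_stage: int) -> None:
    """Hypothetical sketch: all-reduce the gradient of every weight this stage
    holds a tied copy of, e.g. shared_params ==
    [{0: shared.weight, decoder_starting_stage: decoder.embed_tokens.weight}]."""
    for mapping in shared_params:
        if my_stage in mapping:
            param = mapping[my_stage]
            if param.grad is not None:
                # In practice the all-reduce runs over a process group that
                # contains exactly the stages listed in `mapping`.
                dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
```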
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
             binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
@@ -318,6 +339,9 @@ class T5ModelPolicy(T5BasePolicy):
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
     def __init__(self) -> None:
         super().__init__()
 
     def module_policy(self):
         from transformers import T5ForConditionalGeneration
@@ -335,8 +359,38 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
             ],
                 policy=policy,
                 target_key=T5ForConditionalGeneration)
 
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
+                                      new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
+                                      policy=policy)
+
         return policy
 
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+            shared_params = []
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                shared_params.append({
+                    0: module.shared.weight,
+                    decoder_starting_stage: module.decoder.embed_tokens.weight
+                })
+            if id(module.lm_head.weight) == id(module.shared.weight):
+                shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
+            return shared_params
+        return []
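
Note: both get_shared_params implementations rely on T5BasePolicy.distribute_t5_layers, which lives outside this hunk. It assigns encoder and decoder blocks to stages and reports the stage at which the decoder starts. The stand-in below is only meant to illustrate the (layers_per_stage, decoder_starting_stage) return shape; the real splitting heuristic may differ.

```python
def distribute_t5_layers_sketch(num_encoder_layers: int,
                                num_decoder_layers: int,
                                num_stages: int):
    """Hypothetical stand-in for T5BasePolicy.distribute_t5_layers: hand the
    encoder a share of stages proportional to its depth (at least one stage for
    each half) and return (layers_per_stage, decoder_starting_stage)."""
    assert num_stages >= 2, "encoder and decoder need at least one stage each"
    total = num_encoder_layers + num_decoder_layers
    encoder_stages = max(1, min(num_stages - 1, round(num_stages * num_encoder_layers / total)))
    decoder_starting_stage = encoder_stages

    def split_evenly(n_layers: int, n_stages: int):
        base, rem = divmod(n_layers, n_stages)
        return [base + (1 if i < rem else 0) for i in range(n_stages)]

    layers_per_stage = (split_evenly(num_encoder_layers, encoder_stages)
                        + split_evenly(num_decoder_layers, num_stages - encoder_stages))
    return layers_per_stage, decoder_starting_stage


# e.g. 12 encoder + 12 decoder blocks over 4 stages
# -> layers_per_stage == [6, 6, 6, 6], decoder_starting_stage == 2
```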
 
     def postprocess(self):
         super().postprocess()
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
@@ -382,7 +436,7 @@ class T5EncoderPolicy(T5BasePolicy):
         return []
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
             binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
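
Note: both postprocess hunks are cut off after the getattr_ line in this view. The pattern they implement: when tensor parallelism is enabled without pipeline parallelism, every destination listed in binding_map is re-pointed at the same, now sharded, shared weight so the embeddings stay tied. Only getattr_ appears in the excerpt, so the write-back half below is an assumption, written with plain attribute walking instead of ColossalAI's helpers.

```python
def rebind_tied_weights(model, binding_map):
    """Illustrative sketch of the postprocess re-binding loop: after sharding,
    point every destination parameter back at the (sharded) source weight so
    the embeddings stay tied, e.g.
    {"shared.weight": ["encoder.embed_tokens.weight"]}."""
    for src_path, dst_paths in binding_map.items():
        src = get_by_path(model, src_path)
        for dst_path in dst_paths:
            set_by_path(model, dst_path, src)


def get_by_path(obj, path):
    # "encoder.embed_tokens.weight" -> obj.encoder.embed_tokens.weight
    for name in path.split("."):
        obj = getattr(obj, name)
    return obj


def set_by_path(obj, path, value):
    # Walk to the parent module, then overwrite the final attribute.
    *parents, last = path.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, last, value)
```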