Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[pipeline] add pipeline support for all T5 models (#4310)
* complete policy for T5Model & T5ForConditionalGeneration
* modify function signature in forwards
* add forward for T5Model
* add forward for T5ForConditionalGeneration
* fix a bug
* fix hidden_states transporting in decoder
* fix the passing of encoder_outputs
committed by Hongxin Liu
parent d0807122e2, commit 083d7da33d
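The get_shared_params helpers added in the hunks below call T5BasePolicy.distribute_t5_layers, whose implementation is outside this diff. As a hypothetical sketch (the real helper lives elsewhere in t5.py and may balance the split differently), the idea is to spread the encoder and decoder blocks over the pipeline stages and report which stage the decoder, and therefore decoder.embed_tokens, starts on:

# Hypothetical sketch only -- the real T5BasePolicy.distribute_t5_layers may differ.
from typing import List, Tuple


def distribute_t5_layers_sketch(num_encoder_layers: int, num_decoder_layers: int,
                                num_stages: int) -> Tuple[List[int], int]:
    total = num_encoder_layers + num_decoder_layers
    base, remainder = divmod(total, num_stages)
    # Every stage gets `base` blocks; the last `remainder` stages take one extra.
    layers_per_stage = [base + (1 if stage >= num_stages - remainder else 0)
                        for stage in range(num_stages)]
    # Walk the stages until the encoder blocks are used up; the decoder starts right after.
    consumed = 0
    for stage, n in enumerate(layers_per_stage):
        consumed += n
        if consumed >= num_encoder_layers:
            decoder_starting_stage = stage if consumed > num_encoder_layers else stage + 1
            return layers_per_stage, decoder_starting_stage
    return layers_per_stage, num_stages


# e.g. a t5-base-like model (12 encoder + 12 decoder blocks) on 4 stages:
# prints ([6, 6, 6, 6], 2) -- decoder blocks start on stage 2.
print(distribute_t5_layers_sketch(12, 12, 4))

The returned decoder_starting_stage is what pairs the stage-0 copy of shared.weight with the decoder-stage copy of decoder.embed_tokens.weight in the shared-parameter mappings added below.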
@@ -293,21 +293,42 @@ class T5BasePolicy(Policy):
 
 class T5ModelPolicy(T5BasePolicy):
 
+    def __init__(self) -> None:
+        super().__init__()
+
     def module_policy(self):
         from transformers import T5Model
 
-        base_policy = super().module_policy()
+        policy = super().module_policy()
         if self.shard_config.enable_tensor_parallelism:
             self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
                 suffix="shared",
                 target_module=VocabParallelEmbedding1D,
             ),
-                                                        policy=base_policy,
+                                                        policy=policy,
                                                         target_key=T5Model)
-        return base_policy
+
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5Model, new_forward=T5PipelineForwards.t5_model_forward, policy=policy)
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        return super().get_held_layers()
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                return [{0: module.shared.weight, decoder_starting_stage: module.decoder.embed_tokens.weight}]
+        return []
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
            binding_map = {"shared.weight": ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
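The id() comparison in the new get_shared_params works because Hugging Face's T5Model builds its encoder and decoder around one shared embedding table. A quick check on a randomly initialised toy config (no checkpoint download; the sizes are arbitrary and chosen only to keep the model tiny) shows the tie:

from transformers import T5Config, T5Model

# Tiny random T5 just to inspect object identity; sizes are arbitrary.
config = T5Config(vocab_size=128, d_model=32, d_kv=8, d_ff=64, num_layers=2, num_heads=4)
model = T5Model(config)

# Both stacks reuse the very same Parameter object as `shared`, so comparing id()s
# (as get_shared_params does) is enough to detect the tie.
assert id(model.encoder.embed_tokens.weight) == id(model.shared.weight)
assert id(model.decoder.embed_tokens.weight) == id(model.shared.weight)

Under pipeline parallelism the two copies end up on stage 0 and on decoder_starting_stage, which is exactly the mapping the method returns.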
@@ -318,6 +339,9 @@ class T5ModelPolicy(T5BasePolicy):
 
 
 class T5ForConditionalGenerationPolicy(T5BasePolicy):
 
+    def __init__(self) -> None:
+        super().__init__()
+
     def module_policy(self):
         from transformers import T5ForConditionalGeneration
@@ -335,8 +359,38 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
                                                           ],
                                                           policy=policy,
                                                           target_key=T5ForConditionalGeneration)
+
+        if self.pipeline_stage_manager is not None:
+            self.set_pipeline_forward(model_cls=T5ForConditionalGeneration,
+                                      new_forward=T5PipelineForwards.t5_for_conditional_generation_forward,
+                                      policy=policy)
         return policy
 
+    def get_held_layers(self) -> List[nn.Module]:
+        held_layers = super().get_held_layers()
+        if self.pipeline_stage_manager.is_last_stage():
+            held_layers.append(self.model.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager is not None and stage_manager.num_stages > 1:
+            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(len(module.encoder.block),
+                                                                          len(module.decoder.block),
+                                                                          stage_manager.num_stages)
+
+            shared_params = []
+            if id(module.decoder.embed_tokens.weight) == id(module.shared.weight):
+                shared_params.append({
+                    0: module.shared.weight,
+                    decoder_starting_stage: module.decoder.embed_tokens.weight
+                })
+            if id(module.lm_head.weight) == id(module.shared.weight):
+                shared_params.append({0: module.shared.weight, stage_manager.num_stages - 1: module.lm_head.weight})
+            return shared_params
+        return []
+
     def postprocess(self):
         super().postprocess()
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
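How a pipeline runtime consumes the {stage_index: parameter} dicts returned by get_shared_params is not shown in this diff. The sketch below is a hypothetical illustration (sync_tied_gradients, current_stage, and group are made-up names), assuming the process group passed in contains exactly the ranks of the stages listed in each mapping:

import torch.distributed as dist

# Hypothetical illustration, not ColossalAI's runtime code: after backward, every stage
# that holds a copy of a tied weight sums its gradient with the other copies so the
# optimizer updates stay identical across stages.
def sync_tied_gradients(shared_params, current_stage, group=None):
    for mapping in shared_params:           # e.g. {0: shared.weight, 2: decoder.embed_tokens.weight}
        param = mapping.get(current_stage)  # the copy (if any) living on this stage
        if param is not None and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=group)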
@@ -382,7 +436,7 @@ class T5EncoderPolicy(T5BasePolicy):
         return []
 
     def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism:
+        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
             binding_map = {"shared.weight": ["encoder.embed_tokens.weight"]}
             for k, v in binding_map.items():
                 src = getattr_(self.model, k)
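The `and self.pipeline_stage_manager is None` guard added to each postprocess matters because the weight re-binding only makes sense when one rank holds the whole model; under pipeline parallelism the encoder and decoder embeddings can sit on different stages, so the tie is reported through get_shared_params instead. A single-process toy illustration of what that re-binding amounts to (a fresh nn.Embedding stands in for the sharded VocabParallelEmbedding1D, and the setattr step mirrors the T5ModelPolicy binding_map, which is only partially visible in this diff):

import torch.nn as nn
from transformers import T5Config, T5Model

# Tiny random model; a plain nn.Embedding stands in for the sharded replacement of `shared`.
model = T5Model(T5Config(vocab_size=128, d_model=32, d_kv=8, d_ff=64, num_layers=1, num_heads=4))
model.shared = nn.Embedding(128, 32)

# Re-point both stacks at the new table, roughly what the binding_map loop does
# via getattr_/setattr_ when only tensor parallelism is enabled.
src = model.shared.weight
model.encoder.embed_tokens.weight = src
model.decoder.embed_tokens.weight = src
assert model.encoder.embed_tokens.weight is model.decoder.embed_tokens.weight is src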