[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] so that a custom layer distribution is respected

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace static whisper policy calls with calls through self

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
Author:    Insu Jang
Date:      2024-03-27 01:57:00 -04:00
Commit:    00525f7772 (parent e6707a6e8d)
Committer: GitHub
18 changed files with 136 additions and 106 deletions
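The reason these helpers are now reached through self is that a policy subclass can override the generic distribute_layers / get_stage_index helpers and have T5's stage assignment (and the forward replacement below) follow the custom split. A minimal sketch of such an override, assuming the helper signatures implied by their call sites in this diff (the subclass name and the back-loaded split are hypothetical, not part of the commit):

    from typing import List, Tuple

    class BackLoadedT5Policy(T5BasePolicy):
        # Hypothetical policy: later pipeline stages receive the extra layers.
        def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
            base, remainder = divmod(num_layers, num_stages)
            # Near-even split, but any remainder goes to the last stages.
            return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]

        def get_stage_index(self, layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
            # Starting/ending layer indices follow directly from the per-stage counts.
            start = sum(layers_per_stage[:stage])
            return start, start + layers_per_stage[stage]

Before this commit the T5 policy called these helpers as Policy.<staticmethod> / T5BasePolicy.<staticmethod>, so an override like the one above was bypassed, which is the pipeline forward error the title refers to.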


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
@@ -241,9 +243,8 @@ class T5BasePolicy(Policy):
     def postprocess(self):
         return self.model
 
-    @staticmethod
     def distribute_t5_layers(
-        num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+        self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
     ) -> Tuple[List[int], int]:
         """
         Distribute t5 layers into stages when pipeline parallel is used.
@@ -261,7 +262,7 @@ class T5BasePolicy(Policy):
         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
 
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -272,22 +273,21 @@ class T5BasePolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
 
-        encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
-    @staticmethod
     def get_t5_stage_index(
-        layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+        self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
     ) -> Tuple[bool, int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
         if stage < decoder_starting_stage:
-            return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
 
     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""
@@ -302,12 +302,10 @@ class T5BasePolicy(Policy):
         num_decoder_layers = len(decoder.block) if decoder else 0
 
         held_layers = []
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
-            layers_per_stage, stage_manager.stage, decoder_starting_stage
-        )
+        start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         if stage_manager.stage < decoder_starting_stage:
             # current stage is in t5's encoder
@@ -343,10 +341,10 @@ class T5BasePolicy(Policy):
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
 
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
+        stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
         method_replacement = {
             "forward": partial(
@@ -386,7 +384,7 @@ class T5ModelPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )
@@ -434,7 +432,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )
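As a sanity check on the encoder/decoder stage split that distribute_t5_layers still performs, the argmin objective quoted in the comments above can be worked by hand; the numbers below are illustrative, not taken from the commit:

    num_encoder_layers, num_decoder_layers, num_stages = 12, 12, 4

    def objective(i: int) -> float:
        # abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages)
        return abs(num_encoder_layers / i - num_decoder_layers / (num_stages - i))

    scores = [objective(i) for i in range(1, num_stages)]  # [8.0, 0.0, 8.0]
    num_encoder_stages = scores.index(min(scores)) + 1     # 2, same as np.argmin(...) + 1
    num_decoder_stages = num_stages - num_encoder_stages   # 2

With the resulting layers_per_stage and decoder_starting_stage = 2, get_t5_stage_index looks up an encoder stage inside layers_per_stage[:2] and a decoder stage inside layers_per_stage[2:] with its stage number shifted down by 2, exactly as in the hunk above.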