[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
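The whole fix hinges on Python method dispatch: `Policy.distribute_layers(...)` pins the base-class implementation at the call site, while `self.distribute_layers(...)` resolves through the instance, so a subclass that supplies a custom layer distribution is actually honored. A minimal, self-contained sketch of the difference (generic names, not ColossalAI code):

```python
from typing import List


class Base:
    @staticmethod
    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        # default: split layers as evenly as possible across stages
        quotient, remainder = divmod(num_layers, num_stages)
        return [quotient + (1 if i < remainder else 0) for i in range(num_stages)]

    def plan(self, num_layers: int, num_stages: int) -> List[int]:
        # the bug: `Base.distribute_layers(num_layers, num_stages)` would always
        # run the base implementation, silently ignoring subclass overrides
        return self.distribute_layers(num_layers, num_stages)  # the fix: dispatch via self


class FrontHeavy(Base):
    @staticmethod
    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        # custom distribution: pile every layer onto the first stage
        return [num_layers] + [0] * (num_stages - 1)


print(Base().plan(6, 3))        # [2, 2, 2]
print(FrontHeavy().plan(6, 3))  # [6, 0, 0]; a Base-pinned call would give [2, 2, 2]
```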
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import warnings
 from functools import partial
 from typing import Callable, Dict, List, Tuple
@@ -241,9 +243,8 @@ class T5BasePolicy(Policy):
     def postprocess(self):
         return self.model
 
-    @staticmethod
     def distribute_t5_layers(
-        num_encoder_layers: int, num_decoder_layers: int, num_stages: int
+        self, num_encoder_layers: int, num_decoder_layers: int, num_stages: int
     ) -> Tuple[List[int], int]:
         """
         Distribute t5 layers into stages when pipeline parallel is used.
@@ -261,7 +262,7 @@ class T5BasePolicy(Policy):
 
         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
 
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
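To make the objective in the comment above concrete, here is a standalone sketch (function name and numbers are illustrative, not from the patch) that reproduces the argmin selection of the encoder stage count:

```python
import numpy as np


def pick_encoder_stages(num_encoder_layers: int, num_decoder_layers: int, num_stages: int) -> int:
    # minimize the imbalance between encoder and decoder layers per stage
    def objective(num_encoder_stages: int) -> float:
        num_decoder_stages = num_stages - num_encoder_stages
        return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / num_decoder_stages)

    return int(np.argmin([objective(i) for i in range(1, num_stages)])) + 1


# 12 encoder + 12 decoder layers on 4 stages: the objective evaluates to
# 8.0, 0.0, 8.0 for 1, 2, 3 encoder stages, so 2 encoder stages win.
print(pick_encoder_stages(12, 12, 4))  # 2
```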
@@ -272,22 +273,21 @@ class T5BasePolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
 
-        encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
 
-    @staticmethod
     def get_t5_stage_index(
-        layers_per_stage: List[int], stage: int, decoder_starting_stage: int
+        self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
     ) -> Tuple[bool, int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
         if stage < decoder_starting_stage:
-            return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
 
     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""
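The base `Policy.get_stage_index` implementation is not part of this diff; assuming it maps a stage to the [start, end) slice given by prefix sums of `layers_per_stage`, the encoder/decoder branching above works out as in this sketch (values illustrative):

```python
from itertools import accumulate
from typing import List, Tuple


def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # assumed behaviour of the base helper: prefix sums of layers_per_stage
    # give each stage a [start, end) slice of its layer list
    ends = list(accumulate(layers_per_stage))
    starts = [0] + ends[:-1]
    return starts[stage], ends[stage]


# 12 encoder + 12 decoder layers on 4 stages, decoder starting at stage 2
layers_per_stage = [6, 6, 6, 6]
decoder_starting_stage = 2

# stage 1 (encoder side): slice the first decoder_starting_stage entries
print(get_stage_index(layers_per_stage[:decoder_starting_stage], 1))  # (6, 12)
# stage 3 (decoder side): slice the rest, with the stage number re-based
print(get_stage_index(layers_per_stage[decoder_starting_stage:], 3 - decoder_starting_stage))  # (6, 12)
```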
@@ -302,12 +302,10 @@ class T5BasePolicy(Policy):
         num_decoder_layers = len(decoder.block) if decoder else 0
 
         held_layers = []
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
             num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        start_idx, end_idx = T5BasePolicy.get_t5_stage_index(
-            layers_per_stage, stage_manager.stage, decoder_starting_stage
-        )
+        start_idx, end_idx = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         if stage_manager.stage < decoder_starting_stage:
             # current stage is in t5's encoder
@@ -343,10 +341,10 @@ class T5BasePolicy(Policy):
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
 
-        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+        layers_per_stage, decoder_starting_stage = self.distribute_t5_layers(
            num_encoder_layers, num_decoder_layers, stage_manager.num_stages
         )
-        stage_index = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
+        stage_index = self.get_t5_stage_index(layers_per_stage, stage_manager.stage, decoder_starting_stage)
 
         method_replacement = {
             "forward": partial(
@@ -386,7 +384,7 @@ class T5ModelPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )
 
@@ -434,7 +432,7 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None and stage_manager.num_stages > 1:
-            _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
+            _, decoder_starting_stage = self.distribute_t5_layers(
                 len(module.encoder.block), len(module.decoder.block), stage_manager.num_stages
             )
 
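With the helpers now resolved through `self`, a user-defined policy can plug in a custom distribution by overriding the instance method. A hedged sketch; the import path is taken from the file being patched, and the subclass and constructor details are assumed rather than verified against the released API:

```python
from typing import List

from colossalai.shardformer.policies.t5 import T5ModelPolicy  # path assumed


class FrontLightT5Policy(T5ModelPolicy):
    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
        # hypothetical custom split: one layer on stage 0, the rest spread
        # as evenly as possible over the remaining stages
        if num_stages == 1:
            return [num_layers]
        quotient, remainder = divmod(num_layers - 1, num_stages - 1)
        return [1] + [quotient + (1 if i < remainder else 0) for i in range(num_stages - 1)]
```

Because `distribute_t5_layers`, `get_t5_stage_index`, `get_held_layers`, and the pipeline forward replacement all dispatch through `self`, an override like this now takes effect consistently on every stage, which is exactly the error the commit title describes.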