[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
Insu Jang
2024-03-27 01:57:00 -04:00
committed by GitHub
parent e6707a6e8d
commit 00525f7772
18 changed files with 136 additions and 106 deletions

View File

@@ -98,11 +98,11 @@ class OpenMoePolicy(Policy):
module = self.model.model
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
self.append_or_create_method_replacement(
description=method_replacement, policy=policy, target_key=model_cls
)
return
@@ -126,12 +126,9 @@ class OpenMoePolicy(Policy):
held_layers.append(module.norm)
return held_layers
@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages
"""
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""Divide layers into stages"""
if num_layers == 24 and num_stages == 4:
return [7, 7, 7, 3]
elif num_layers == 24 and num_stages == 2:
@@ -142,7 +139,7 @@ class OpenMoePolicy(Policy):
return [8, 4]
else:
print(f"num_layers: {num_layers}, num_stages: {num_stages} not optimized, use origin pp policy")
return Policy.distribute_layers(num_layers, num_stages)
return super().distribute_layers(num_layers, num_stages)
class OpenMoeModelPolicy(OpenMoePolicy):