mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-28 04:55:25 +00:00
[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution
* Change static methods for t5 layer distribution to member functions
* Change static methods for whisper layer distribution to member functions
* Replace whisper policy usage with self one
* Fix test case to use non-static layer distribution methods
* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
@@ -188,7 +188,7 @@ class GPT2Policy(Policy):
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = Policy.get_stage_index(
|
||||
stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
@@ -229,7 +229,7 @@ class GPT2Policy(Policy):
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = Policy.get_stage_index(
|
||||
stage_manager.stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
@@ -243,8 +243,8 @@ class GPT2Policy(Policy):
|
||||
)
|
||||
}
|
||||
else:
|
||||
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward,
|
||||
|
Reference in New Issue
Block a user