mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-26 12:14:02 +00:00
[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
@@ -167,7 +167,7 @@ class LlamaPolicy(Policy):
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_manager.stage_indices = Policy.get_stage_index(
|
||||
stage_manager.stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
@@ -178,8 +178,8 @@ class LlamaPolicy(Policy):
|
||||
}
|
||||
|
||||
else:
|
||||
layers_per_stage = Policy.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
|
||||
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
@@ -207,7 +207,7 @@ class LlamaPolicy(Policy):
|
||||
layers_per_stage = self.distribute_layers(
|
||||
len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
|
||||
)
|
||||
stage_indices = Policy.get_stage_index(
|
||||
stage_indices = self.get_stage_index(
|
||||
layers_per_stage,
|
||||
stage_manager.stage,
|
||||
num_model_chunks=stage_manager.num_model_chunks,
|
||||
|
Reference in New Issue
Block a user