mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 03:52:01 +00:00
[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)
* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
@@ -10,9 +10,12 @@ def test_t5_pipeline_distribution():
|
||||
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
|
||||
}
|
||||
|
||||
policy = T5BasePolicy()
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
|
||||
_, decoder_starting_stage = policy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
test_dict["num_decoder_layers"][i],
|
||||
test_dict["num_stages"][i],
|
||||
)
|
||||
assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
|
||||
|
||||
@@ -32,14 +35,15 @@ def test_t5_pipeline_layers():
|
||||
}
|
||||
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
|
||||
policy = T5BasePolicy()
|
||||
layers_per_stage, decoder_starting_stage = policy.distribute_t5_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
test_dict["num_decoder_layers"][i],
|
||||
test_dict["num_stages"][i],
|
||||
)
|
||||
|
||||
for stage in range(test_dict["num_stages"][i]):
|
||||
start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
|
||||
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(
|
||||
layers_per_stage, stage, decoder_starting_stage
|
||||
)
|
||||
predicted_start, predicted_end = policy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage)
|
||||
assert start_idx == predicted_start
|
||||
assert end_idx == predicted_end
|
||||
|
@@ -10,9 +10,12 @@ def test_whisper_pipeline_distribution():
|
||||
"decoder_starting_stage": [1, 1, 2, 2, 3, 1, 5, 2],
|
||||
}
|
||||
|
||||
policy = WhisperPolicy()
|
||||
for i in range(num_test_cases):
|
||||
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
|
||||
_, decoder_starting_stage = policy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
test_dict["num_decoder_layers"][i],
|
||||
test_dict["num_stages"][i],
|
||||
)
|
||||
assert test_dict["decoder_starting_stage"][i] == decoder_starting_stage
|
||||
|
||||
@@ -31,14 +34,17 @@ def test_whisper_pipeline_layers():
|
||||
],
|
||||
}
|
||||
|
||||
policy = WhisperPolicy()
|
||||
for i in range(num_test_cases):
|
||||
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i], test_dict["num_decoder_layers"][i], test_dict["num_stages"][i]
|
||||
layers_per_stage, decoder_starting_stage = policy.distribute_whisper_layers(
|
||||
test_dict["num_encoder_layers"][i],
|
||||
test_dict["num_decoder_layers"][i],
|
||||
test_dict["num_stages"][i],
|
||||
)
|
||||
|
||||
for stage in range(test_dict["num_stages"][i]):
|
||||
start_idx, end_idx = test_dict["layers_per_stage"][i][stage]
|
||||
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(
|
||||
predicted_start, predicted_end = policy.get_whisper_stage_index(
|
||||
layers_per_stage, stage, decoder_starting_stage
|
||||
)
|
||||
assert start_idx == predicted_start
|
||||
|
Reference in New Issue
Block a user