[pipeline] move bert related pipeline components to shardformer (#4187)

* move bert related pipeline components to shardformer

* fix bugs

* revision

* fix bert model tests

* fix bert_lm_head model tests

* fix tests

* fix tests

* done checks

* skip bloom
This commit is contained in:
Jianghai
2023-07-07 15:41:00 +08:00
committed by Hongxin Liu
parent c5ea728016
commit f3bcc292c8
9 changed files with 556 additions and 65 deletions

View File

@@ -109,33 +109,3 @@ class Policy:
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params
@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages
# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
@staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
"""
get the start index and end index of layers for each stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]
return [start_idx, end_idx]