[pipeline] move bert related pipeline components to shardformer (#4187)

* move bert related pipeline components to shardformer

* fix bugs

* revision

* fix bert model tests

* fix bert_lm_head model tests

* fix tests

* fix tests

* done checks

* skip bloom
Jianghai authored on 2023-07-07 15:41:00 +08:00; committed by Hongxin Liu
parent c5ea728016
commit f3bcc292c8
9 changed files with 556 additions and 65 deletions


@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
@@ -176,3 +177,33 @@ class Policy(ABC):
            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
        """
        return []

    @staticmethod
    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        """Divide layers as evenly as possible into stages."""
        quotient = num_layers // num_stages
        remainder = num_layers % num_stages

        # calculate the num_layers per stage
        layers_per_stage = [quotient] * num_stages

        # deal with the remaining layers: assign them to the middle stages
        if remainder > 0:
            start_position = num_stages // 2 - remainder // 2
            for i in range(start_position, start_position + remainder):
                layers_per_stage[i] += 1
        return layers_per_stage

    @staticmethod
    def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
        """Get the start index and end index of layers for the given stage."""
        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
        start_idx = num_layers_per_stage_accumulated[stage]
        end_idx = num_layers_per_stage_accumulated[stage + 1]
        return [start_idx, end_idx]
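
For reference, a minimal standalone sketch of how the two helpers above fit together (the function bodies mirror the diff; the layer and stage counts are hypothetical examples): distribute_layers balances the layer counts across pipeline stages, and get_stage_index turns those counts into a [start, end) slice of layer indices for one stage.

import numpy as np
from typing import List

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # evenly split the layers, then hand any remainder to the middle stages
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    if remainder > 0:
        start_position = num_stages // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
    # prefix sums give the [start, end) layer slice owned by `stage`
    accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    return [int(accumulated[stage]), int(accumulated[stage + 1])]

# e.g. 26 transformer layers over 4 pipeline stages (hypothetical numbers)
counts = distribute_layers(26, 4)       # -> [6, 7, 7, 6]
print(get_stage_index(counts, 1))       # -> [6, 13]

In the BERT policies moved by this commit, a slice like this is what decides which encoder layers are kept on the current pipeline stage.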