[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
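As background on the "inplace sharding" bullets above: instead of allocating a new parallel module and copying tensors into it, the existing module's parameter data is swapped for its local shard in place, so anything that shares the same Parameter object (tied embeddings, shared params reported across pipeline stages) keeps pointing at the sharded tensor afterwards. The snippet below is only a minimal sketch of that idea; the helper name and the column-wise split are assumptions for illustration, not the shardformer API.

import torch.distributed as dist
import torch.nn as nn


def shard_linear_col_inplace(module: nn.Linear, process_group=None) -> nn.Linear:
    """Illustrative sketch: shard an existing nn.Linear column-wise in place.

    The parameter *data* is replaced by this rank's shard rather than building
    a fresh parallel layer, so other modules tying ``module.weight`` still see
    the (now sharded) parameter. Not the actual shardformer implementation.
    """
    rank = dist.get_rank(process_group)
    world_size = dist.get_world_size(process_group)
    assert module.out_features % world_size == 0, "out_features must divide evenly"

    # Keep only this rank's slice of the output dimension.
    module.weight.data = module.weight.data.chunk(world_size, dim=0)[rank].clone()
    if module.bias is not None:
        module.bias.data = module.bias.data.chunk(world_size, dim=0)[rank].clone()
    module.out_features = module.out_features // world_size
    return module

Because the Parameter object itself survives, a weight that is tied elsewhere (for example an LM head reusing the input embedding) is sharded along with it, which lines up with the per-policy postprocess re-tying being removed in the diff below.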
Author: Hongxin Liu
Date: 2023-07-20 10:39:06 +08:00
Parent: 2a2eacfaf1
Commit: d921ce8391
26 changed files with 371 additions and 340 deletions


@@ -1,9 +1,7 @@
import warnings
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
@@ -27,7 +25,6 @@ from transformers.utils import logging
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
-from .._utils import getattr_, setattr_
from ..modeling.bloom import build_bloom_alibi_tensor_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -229,20 +226,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
            # tie weights
            return [{
                0: bloom_model.transformer.word_embeddings.weight,
-                self.stage_manager.num_stages - 1: bloom_model.lm_head.weight
+                self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
            }]
        return []

-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
-            for k, v in binding_map.items():
-                param = getattr_(self.model, k)
-                # tie weights
-                setattr_(self.model, v, param)
-        return self.model

class BloomForSequenceClassificationPolicy(BloomPolicy):
@@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward(
    all_cross_attentions = None
    if stage_manager.is_last_stage():
        batch_size = hidden_states.shape[0]
-        #update batch size
+        # update batch size
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
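Reading the BloomForCausalLMPolicy hunk above (@@ -229,20 +226,10 @@): the postprocess hook that re-tied lm_head.weight to the word embeddings through getattr_/setattr_ is removed, and the tied pair is instead reported by get_shared_params, keyed by the pipeline stage that owns each copy (stage 0 for the embedding, the last stage for the LM head). The following is a reconstruction of that method from the hunk for readability; the class scaffolding and attribute assumptions are illustrative and noted in the docstring.

from typing import Dict, List

from torch import Tensor


class BloomForCausalLMPolicySketch:
    """Illustrative reconstruction of the tied-weight handling shown above.

    Assumes ``model`` is a BloomForCausalLM instance and
    ``pipeline_stage_manager`` is the policy's stage manager (None when
    pipeline parallelism is disabled), matching the surrounding diff.
    """

    def __init__(self, model, pipeline_stage_manager=None):
        self.model = model
        self.pipeline_stage_manager = pipeline_stage_manager

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        bloom_model = self.model
        if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
            # The word embedding lives on the first stage and the LM head on the
            # last stage; reporting the tied pair lets the pipeline runtime keep
            # the two copies in sync, replacing the deleted postprocess re-tying.
            return [{
                0: bloom_model.transformer.word_embeddings.weight,
                self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight,
            }]
        return []

In the pure tensor-parallel case the old re-tying branch disappears without a replacement, which is consistent with inplace sharding: the embedding and the LM head keep sharing a single parameter object, so there is nothing left to re-tie.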