Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-27 12:43:02 +00:00
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
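For context: in-place sharding means a policy swaps a module's parameters for their local shards instead of building a replacement module, so objects that already reference those modules (tied weights, pipeline schedules) stay valid. Below is a minimal standalone sketch of that idea for a column-parallel linear layer; the function name shard_linear_inplace and its logic are illustrative assumptions, not ColossalAI's API (the real replacements come from colossalai.shardformer.layer, imported as col_nn in the diff below).

import torch
import torch.nn as nn

def shard_linear_inplace(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Illustrative sketch: split the output dimension across tensor-parallel ranks
    (column parallelism) while keeping the original module object alive."""
    out_features = linear.out_features
    assert out_features % world_size == 0, "output dim must divide evenly across ranks"
    shard = out_features // world_size
    start, end = rank * shard, (rank + 1) * shard

    with torch.no_grad():
        # replace the parameters in place with their local shards
        linear.weight = nn.Parameter(linear.weight[start:end, :].clone())
        if linear.bias is not None:
            linear.bias = nn.Parameter(linear.bias[start:end].clone())
    linear.out_features = shard
    return linear

# usage: every rank keeps the same module instance, only its parameters shrink
layer = nn.Linear(16, 32)
shard_linear_inplace(layer, rank=0, world_size=4)
print(layer.weight.shape)  # torch.Size([8, 16])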
@@ -1,9 +1,7 @@
import warnings
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

@@ -27,7 +25,6 @@ from transformers.utils import logging
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager

from .._utils import getattr_, setattr_
from ..modeling.bloom import build_bloom_alibi_tensor_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

@@ -229,20 +226,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
            # tie weights
            return [{
                0: bloom_model.transformer.word_embeddings.weight,
                self.stage_manager.num_stages - 1: bloom_model.lm_head.weight
                self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
            }]
        return []

    def postprocess(self):
        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
            binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

            for k, v in binding_map.items():
                param = getattr_(self.model, k)
                # tie weights
                setattr_(self.model, v, param)
        return self.model

class BloomForSequenceClassificationPolicy(BloomPolicy):
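The postprocess hunk above rebinds lm_head.weight to the word-embedding weight after tensor-parallel sharding via the getattr_/setattr_ helpers imported from .._utils. A rough sketch of what that binding amounts to, assuming those helpers simply walk dotted attribute paths (the re-implementations below are illustrative, not the library's code):

import torch.nn as nn

def getattr_(obj, path: str):
    # walk a dotted attribute path, e.g. "transformer.word_embeddings.weight"
    for name in path.split("."):
        obj = getattr(obj, name)
    return obj

def setattr_(obj, path: str, value):
    # set the attribute named by the last path segment on its parent object
    *parents, last = path.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, last, value)

def tie_weights(model: nn.Module, binding_map: dict) -> nn.Module:
    # make every target parameter the same object as its source,
    # so the pair stays identical after in-place sharding
    for src, dst in binding_map.items():
        setattr_(model, dst, getattr_(model, src))
    return model

# mirrors the binding_map in the hunk above (tie_weights is a hypothetical helper):
# tie_weights(model, {"transformer.word_embeddings.weight": "lm_head.weight"})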
@@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward(
    all_cross_attentions = None
    if stage_manager.is_last_stage():
        batch_size = hidden_states.shape[0]
        #update batch size
        # update batch size
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)
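This hunk only fixes comment spacing, but the surrounding code is the last-stage branch of the pipeline-parallel sequence-classification forward: earlier stages pass hidden states along, and only the final stage applies the score head. A small standalone sketch of that control flow, with purely illustrative names (SequenceClassificationStage is not part of ColossalAI):

import torch
import torch.nn as nn

class SequenceClassificationStage(nn.Module):
    """Toy stand-in for the last pipeline stage of a sequence classifier."""

    def __init__(self, hidden_size: int, num_labels: int):
        super().__init__()
        self.score = nn.Linear(hidden_size, num_labels, bias=False)

    def forward(self, hidden_states: torch.Tensor, is_last_stage: bool):
        if not is_last_stage:
            # intermediate stages just hand the activations to the next stage
            return {"hidden_states": hidden_states}
        # update batch size from the activations actually received by this stage
        batch_size = hidden_states.shape[0]
        logits = self.score(hidden_states)
        return {"logits": logits, "batch_size": batch_size}

stage = SequenceClassificationStage(hidden_size=64, num_labels=2)
out = stage(torch.randn(4, 128, 64), is_last_stage=True)
print(out["logits"].shape)  # torch.Size([4, 128, 2])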