mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-27 04:33:04 +00:00
[shardformer]: Feature/shardformer, add some docstring and readme (#3816)
* init shardformer code structure * add implement of sharder (inject and replace) * add implement of replace layer to colossal layer * separate different layer policy, add some notion * implement 1d and 2d slicer, can tell col or row * fix bug when slicing and inject model * fix some bug; add inference test example * add share weight and train example * add train * add docstring and readme * add docstring for other files * pre-commit
This commit is contained in:
@@ -1,56 +1,57 @@
|
||||
from typing import Dict, List, Tuple, Type, Any, Callable
|
||||
import torch.nn as nn
|
||||
from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
|
||||
import colossalai.nn as col_nn
|
||||
from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
|
||||
|
||||
import colossalai.nn as col_nn
|
||||
|
||||
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
|
||||
|
||||
|
||||
class BertPolicy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
|
||||
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||
return {
|
||||
BertLayer: Argument(
|
||||
attr_dict = {
|
||||
# 1. shard hidden size
|
||||
"attention.self.all_head_size": config.hidden_size // world_size,
|
||||
"crossattention.self.all_head_size": config.hidden_size // world_size,
|
||||
# 2. shard number of heads
|
||||
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
|
||||
},
|
||||
param_funcs = [
|
||||
BertPolicy.attn_in,
|
||||
BertPolicy.attn_out,
|
||||
BertPolicy.mlp_in,
|
||||
BertPolicy.mlp_out
|
||||
]
|
||||
),
|
||||
BertEmbeddings: Argument(
|
||||
attr_dict = {
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
"word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
|
||||
},
|
||||
param_funcs = [
|
||||
BertPolicy.embedding,
|
||||
],
|
||||
binding_layers = [
|
||||
BertLMPredictionHead,
|
||||
]
|
||||
),
|
||||
BertLMPredictionHead: Argument(
|
||||
attr_dict = {
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
},
|
||||
param_funcs = [
|
||||
BertPolicy.unembedding,
|
||||
]
|
||||
)
|
||||
BertLayer:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard hidden size
|
||||
"attention.self.all_head_size": config.hidden_size // world_size,
|
||||
"crossattention.self.all_head_size": config.hidden_size // world_size,
|
||||
# 2. shard number of heads
|
||||
"attention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
"crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
|
||||
},
|
||||
param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
|
||||
BertEmbeddings:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
"word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
|
||||
},
|
||||
param_funcs=[
|
||||
BertPolicy.embedding,
|
||||
]),
|
||||
BertLMPredictionHead:
|
||||
Argument(
|
||||
attr_dict={
|
||||
# 1. shard vocab size
|
||||
# "word_embeddings.num_embeddings": config.vocab_size // world_size,
|
||||
# 2. add the size of the sliced embedding layer excluding the last slice
|
||||
},
|
||||
param_funcs=[
|
||||
BertPolicy.unembedding,
|
||||
])
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def binding_policy() -> Dict:
|
||||
return {
|
||||
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -89,9 +90,8 @@ class BertPolicy(Policy):
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
ignore=True,
|
||||
),
|
||||
|
||||
]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def attn_out() -> List:
|
||||
return [
|
||||
@@ -107,17 +107,17 @@ class BertPolicy(Policy):
|
||||
ignore=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def mlp_in() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
Col_Layer(
|
||||
weight="intermediate.dense.weight",
|
||||
bias="intermediate.dense.bias",
|
||||
replace_layer=col_nn.Linear1D_Col,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def mlp_out() -> List:
|
||||
return [
|
||||
@@ -130,13 +130,11 @@ class BertPolicy(Policy):
|
||||
|
||||
@staticmethod
|
||||
def embedding() -> List:
|
||||
return [
|
||||
Col_Layer(
|
||||
weight="word_embeddings.weight",
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
]
|
||||
|
||||
return [Col_Layer(
|
||||
weight="word_embeddings.weight",
|
||||
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||
)]
|
||||
|
||||
@staticmethod
|
||||
def unembedding() -> List:
|
||||
return [
|
||||
@@ -148,16 +146,21 @@ class BertPolicy(Policy):
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
from transformers import BertForMaskedLM
|
||||
|
||||
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
|
||||
|
||||
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Tuple[nn.Module, nn.Module]:
|
||||
return (BertForMaskedLM, BertForMaskedLM_)
|
||||
|
||||
|
||||
|
||||
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
||||
@staticmethod
|
||||
def inject_policy() -> Dict:
|
||||
return {}
|
||||
@@ -165,4 +168,4 @@ class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
||||
# model = BertForMaskedLM.from_pretrained("bert-base-uncased")
|
||||
# _ = BertForMaskedLMPolicy(model)
|
||||
# print(isinstance(model,list(_.inject_policy().keys())[0]))
|
||||
# print(isinstance(model,list(_.inject_policy().keys())[0]))
|
||||
|
Reference in New Issue
Block a user