[shardformer]: Feature/shardformer, add some docstring and readme (#3816)

* init shardformer code structure

* add implementation of the sharder (inject and replace)

* add implementation of replacing layers with Colossal-AI layers

* separate the policies for different layers, add some notes

* implement 1D and 2D slicers that can tell column from row slicing

* fix bugs when slicing and injecting the model

* fix some bugs; add an inference test example

* add shared weights and a training example

* add training

* add docstring and readme

* add docstrings for other files

* pre-commit
Author: FoolPlayer
Date: 2023-05-24 10:26:46 +08:00
Committed by: Frank Lee
Parent: 8d68de767d
Commit: 8cc11235c0
14 changed files with 616 additions and 343 deletions
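
For context on the numbers involved: `BertPolicy.argument_policy` in the diff below divides the attention width and head count evenly across ranks, and shards the vocabulary with ceiling division so every slice except possibly the last has the same size. A quick sanity check of those formulas with BERT-base-like values (the concrete numbers are an assumption for illustration, not part of this commit):

```python
# Illustrative values only (roughly bert-base-uncased); not taken from this commit.
hidden_size = 768
num_attention_heads = 12
vocab_size = 30522
world_size = 4

# Per-rank attention shard, mirroring argument_policy for BertLayer.
all_head_size = hidden_size // world_size            # 192
heads_per_rank = num_attention_heads // world_size   # 3

# Vocabulary shard for BertEmbeddings via ceiling division: ranks 0-2 hold 7631
# rows each, and the last slice holds the remaining 30522 - 3 * 7631 = 7629 rows.
dim_size = (vocab_size + world_size - 1) // world_size   # 7631

print(all_head_size, heads_per_rank, dim_size)
```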

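The `Col_Layer`/`Row_Layer` distinction used throughout the policy below follows the usual tensor-parallel convention: a column-sliced linear layer keeps a shard of the output features, a row-sliced one keeps a shard of the input features. A minimal sketch of that 1D slicing idea, assuming a plain `nn.Linear` weight of shape `(out_features, in_features)` (this is not the slicer implementation added by the commit):

```python
import torch


def slice_weight_1d(weight: torch.Tensor, rank: int, world_size: int, is_col: bool) -> torch.Tensor:
    """Return this rank's shard of a linear weight.

    Col_Layer-style slicing splits the output dimension (dim 0);
    Row_Layer-style slicing splits the input dimension (dim 1).
    """
    dim = 0 if is_col else 1
    return torch.chunk(weight, world_size, dim=dim)[rank]


weight = torch.randn(768, 768)
print(slice_weight_1d(weight, rank=0, world_size=4, is_col=True).shape)   # torch.Size([192, 768])
print(slice_weight_1d(weight, rank=0, world_size=4, is_col=False).shape)  # torch.Size([768, 192])
```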

@@ -1,56 +1,57 @@
-from typing import Dict, List, Tuple, Type, Any, Callable
-import torch.nn as nn
-from .basepolicy import Policy, Layer, Argument, Col_Layer, Row_Layer
-import colossalai.nn as col_nn
-from transformers.models.bert.modeling_bert import BertLayer, BertEmbeddings, BertLMPredictionHead
 from dataclasses import dataclass
+from typing import Any, Callable, Dict, List, Tuple, Type
+import torch.nn as nn
+from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
+import colossalai.nn as col_nn
+from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
 class BertPolicy(Policy):
     @staticmethod
-    def argument_policy(config, world_size: int) -> Dict[nn.Module,Argument]:
+    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
         return {
-            BertLayer: Argument(
-                attr_dict = {
-                    # 1. shard hidden size
-                    "attention.self.all_head_size": config.hidden_size // world_size,
-                    "crossattention.self.all_head_size": config.hidden_size // world_size,
-                    # 2. shard number of heads
-                    "attention.self.num_attention_heads": config.num_attention_heads // world_size,
-                    "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
-                },
-                param_funcs = [
-                    BertPolicy.attn_in,
-                    BertPolicy.attn_out,
-                    BertPolicy.mlp_in,
-                    BertPolicy.mlp_out
-                ]
-            ),
-            BertEmbeddings: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                    "word_embeddings.dim_size": (config.vocab_size+world_size-1) // world_size,
-                },
-                param_funcs = [
-                    BertPolicy.embedding,
-                ],
-                binding_layers = [
-                    BertLMPredictionHead,
-                ]
-            ),
-            BertLMPredictionHead: Argument(
-                attr_dict = {
-                    # 1. shard vocab size
-                    # "word_embeddings.num_embeddings": config.vocab_size // world_size,
-                    # 2. add the size of the sliced embedding layer excluding the last slice
-                },
-                param_funcs = [
-                    BertPolicy.unembedding,
-                ]
-            )
+            BertLayer:
+                Argument(
+                    attr_dict={
+                        # 1. shard hidden size
+                        "attention.self.all_head_size": config.hidden_size // world_size,
+                        "crossattention.self.all_head_size": config.hidden_size // world_size,
+                        # 2. shard number of heads
+                        "attention.self.num_attention_heads": config.num_attention_heads // world_size,
+                        "crossattention.self.num_attention_heads": config.num_attention_heads // world_size,
+                    },
+                    param_funcs=[BertPolicy.attn_in, BertPolicy.attn_out, BertPolicy.mlp_in, BertPolicy.mlp_out]),
+            BertEmbeddings:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                        "word_embeddings.dim_size": (config.vocab_size + world_size - 1) // world_size,
+                    },
+                    param_funcs=[
+                        BertPolicy.embedding,
+                    ]),
+            BertLMPredictionHead:
+                Argument(
+                    attr_dict={
+                        # 1. shard vocab size
+                        # "word_embeddings.num_embeddings": config.vocab_size // world_size,
+                        # 2. add the size of the sliced embedding layer excluding the last slice
+                    },
+                    param_funcs=[
+                        BertPolicy.unembedding,
+                    ])
         }
+    @staticmethod
+    def binding_policy() -> Dict:
+        return {
+            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
+        }
     @staticmethod
@@ -89,9 +90,8 @@ class BertPolicy(Policy):
                 replace_layer=col_nn.Linear1D_Col,
                 ignore=True,
             ),
         ]
     @staticmethod
     def attn_out() -> List:
         return [
@@ -107,17 +107,17 @@ class BertPolicy(Policy):
                 ignore=True,
             ),
         ]
     @staticmethod
     def mlp_in() -> List:
         return [
-            Col_Layer(
+            Col_Layer(
                 weight="intermediate.dense.weight",
                 bias="intermediate.dense.bias",
                 replace_layer=col_nn.Linear1D_Col,
             ),
         ]
     @staticmethod
     def mlp_out() -> List:
         return [
@@ -130,13 +130,11 @@ class BertPolicy(Policy):
     @staticmethod
     def embedding() -> List:
-        return [
-            Col_Layer(
-                weight="word_embeddings.weight",
-                replace_layer=col_nn.VocabParallelEmbedding1D,
-            )
-        ]
+        return [Col_Layer(
+            weight="word_embeddings.weight",
+            replace_layer=col_nn.VocabParallelEmbedding1D,
+        )]
     @staticmethod
     def unembedding() -> List:
         return [
@@ -148,16 +146,21 @@ class BertPolicy(Policy):
             )
         ]
+from transformers import BertForMaskedLM
+from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
 class BertForMaskedLMPolicy(BertPolicy):
+    @staticmethod
+    def inject_policy() -> Tuple[nn.Module, nn.Module]:
+        return (BertForMaskedLM, BertForMaskedLM_)
 class BertForSequenceClassificationPolicy(BertPolicy):
     @staticmethod
     def inject_policy() -> Dict:
         return {}
@@ -165,4 +168,4 @@ class BertForSequenceClassificationPolicy(BertPolicy):
 # model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 # _ = BertForMaskedLMPolicy(model)
-# print(isinstance(model,list(_.inject_policy().keys())[0]))
+# print(isinstance(model,list(_.inject_policy().keys())[0]))
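
After slicing, weights that were tied in the original model still have to share storage; `binding_policy()` above records exactly one such pair for BERT, mapping the word-embedding weight to the LM-head decoder weight. A minimal sketch of how a sharder could apply such a mapping (the `bind_shared_weights` helper is an illustrative assumption, not the implementation shipped in this commit):

```python
import torch.nn as nn


def _locate(model: nn.Module, dotted_path: str):
    """Resolve a path like 'a.b.c.weight' into (module a.b.c, 'weight')."""
    *module_path, param_name = dotted_path.split(".")
    module = model
    for attr in module_path:
        module = getattr(module, attr)
    return module, param_name


def bind_shared_weights(model: nn.Module, binding_policy: dict) -> None:
    """Re-tie shared parameters after sharding.

    ``binding_policy`` maps a source parameter path to a destination path, e.g.
    {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    as returned by BertPolicy.binding_policy().
    """
    for src_path, dst_path in binding_policy.items():
        src_module, src_name = _locate(model, src_path)
        dst_module, dst_name = _locate(model, dst_path)
        # Point the destination attribute at the (already sharded) source parameter.
        setattr(dst_module, dst_name, getattr(src_module, src_name))
```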