[shardformer] support vision transformer (#4096)

* first v of vit shardformer

* keep vit

* update

* vit shard add vitattention vitlayer

* update num head shard para

* finish test for vit

* add new_model_class & postprocess

* add vit readme

* delete old files & fix the conflict

* fix sth
This commit is contained in:
Kun Lin
2023-06-28 13:28:18 +08:00
committed by Frank Lee
parent ac80937138
commit 8af29ee47a
10 changed files with 159 additions and 8 deletions

View File

@@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
- [ ] GPT Neo
- [ ] GPT-J
- [ ] CV
- [ ] ViT
- [x] ViT
- [ ] BEiT
- [ ] SwinTransformer
- [ ] SwinTransformer V2

View File

@@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
def reduce_backward(input_, process_group):
return _ReduceBackward.apply(input_, process_group)
return _ReduceBackward.apply(input_, process_group)

View File

@@ -61,4 +61,4 @@ class FusedLayerNorm():
# copy weight and bias
layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias)
return layernorm
return layernorm

View File

@@ -316,4 +316,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
])
}
module_policy.update(addon_module)
return module_policy
return module_policy

View File

@@ -167,4 +167,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
class T5EncoderPolicy(T5ModelPolicy):
pass
pass

View File

@@ -0,0 +1,96 @@
from typing import Dict, Union
import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ViTPolicy(Policy):
def preprocess(self):
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
return {
ViTEmbeddings:
ModulePolicyDescription(
attribute_replacement{},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=Dropout1D,
)
]
),
ViTLayer:
ModulePolicyDescription(
attribute_replacement{
"attention.attention.num_attention_heads":
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
"attention.attention.all_head_size":
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.attention.query",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.key",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.value",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.attention.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
target_module=Dropout1D,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="output.dropout",
target_module=Dropout1D,
),
]
),
}
def new_model_class(self):
return None
def postprocess(self):
return self.model