mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[shardformer] support vision transformer (#4096)
* first v of vit shardformer * keep vit * update * vit shard add vitattention vitlayer * update num head shard para * finish test for vit * add new_model_class & postprocess * add vit readme * delete old files & fix the conflict * fix sth
This commit is contained in:
@@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
|
||||
- [ ] GPT Neo
|
||||
- [ ] GPT-J
|
||||
- [ ] CV
|
||||
- [ ] ViT
|
||||
- [x] ViT
|
||||
- [ ] BEiT
|
||||
- [ ] SwinTransformer
|
||||
- [ ] SwinTransformer V2
|
||||
|
@@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):
|
||||
|
||||
|
||||
def reduce_backward(input_, process_group):
|
||||
return _ReduceBackward.apply(input_, process_group)
|
||||
return _ReduceBackward.apply(input_, process_group)
|
@@ -61,4 +61,4 @@ class FusedLayerNorm():
|
||||
# copy weight and bias
|
||||
layernorm.weight.copy_(module.weight)
|
||||
layernorm.bias.copy_(module.bias)
|
||||
return layernorm
|
||||
return layernorm
|
@@ -316,4 +316,4 @@ class BertForMultipleChoicePolicy(BertPolicy):
|
||||
])
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
return module_policy
|
||||
return module_policy
|
@@ -167,4 +167,4 @@ class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
||||
|
||||
|
||||
class T5EncoderPolicy(T5ModelPolicy):
|
||||
pass
|
||||
pass
|
96
colossalai/shardformer/policies/vit.py
Normal file
96
colossalai/shardformer/policies/vit.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention
|
||||
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
class ViTPolicy(Policy):
|
||||
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
return {
|
||||
ViTEmbeddings:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement{},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=Dropout1D,
|
||||
)
|
||||
]
|
||||
),
|
||||
ViTLayer:
|
||||
ModulePolicyDescription(
|
||||
attribute_replacement{
|
||||
"attention.attention.num_attention_heads":
|
||||
self.model.config.num_attention_heads//self.shard_config.tensor_parallel_size,
|
||||
"attention.attention.all_head_size":
|
||||
self.model.config.hidden_size//self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.query",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.key",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.value",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.attention.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attention.output.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="intermediate.dense",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dense",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="output.dropout",
|
||||
target_module=Dropout1D,
|
||||
),
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user