mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-25 03:31:56 +00:00
[shardformer] added embedding gradient check (#4124)
This commit is contained in:
@@ -55,7 +55,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
|
||||
except AttributeError:
|
||||
if ignore:
|
||||
return
|
||||
raise AttributeError(f"Object {obj} has no attribute {attr}")
|
||||
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
@@ -76,5 +76,5 @@ def getattr_(obj, attr: str, ignore: bool = False):
|
||||
except AttributeError:
|
||||
if ignore:
|
||||
return None
|
||||
raise AttributeError(f"Object {obj} has no attribute {attr}")
|
||||
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
|
||||
return obj
|
||||
|
@@ -97,7 +97,7 @@ class BertPolicy(Policy):
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
target_module=col_nn.DropoutForReplicatedInput,
|
||||
)
|
||||
])
|
||||
}
|
||||
|
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
@@ -73,7 +75,6 @@ class BloomPolicy(Policy):
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
# TODO:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
@@ -161,13 +162,12 @@ class BloomPolicy(Policy):
|
||||
|
||||
def new_model_class(self):
|
||||
# do nothing
|
||||
return self.model
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
# BertModel
|
||||
class BloomModelPolicy(BloomPolicy):
|
||||
pass
|
||||
|
||||
@@ -191,6 +191,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
|
||||
if not isinstance(param, nn.Parameter):
|
||||
param = nn.Parameter(param)
|
||||
|
||||
# tie weights
|
||||
setattr_(self.model, k, param)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
@@ -35,7 +36,7 @@ class OPTPolicy(Policy):
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
]),
|
||||
OPTDecoderLayer:
|
||||
@@ -127,6 +128,18 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = {
|
||||
'model.decoder.embed_tokens': 'lm_head',
|
||||
}
|
||||
|
||||
for k, v in binding_map.items():
|
||||
src_mod = getattr_(self.model, k)
|
||||
dst_mod = getattr_(self.model, v)
|
||||
dst_mod.weight = src_mod.weight
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
class OPTForSequenceClassificationPolicy(OPTPolicy):
|
||||
|
||||
|
@@ -1,11 +1,20 @@
|
||||
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
|
||||
from colossalai.shardformer.layer import (
|
||||
DropoutForParallelInput,
|
||||
Embedding1D,
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
VocabParallelEmbedding1D,
|
||||
)
|
||||
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||
|
||||
|
||||
class T5ModelPolicy(Policy):
|
||||
class T5BasePolicy(Policy):
|
||||
|
||||
def config_sanity_check(self):
|
||||
pass
|
||||
@@ -33,7 +42,7 @@ class T5ModelPolicy(Policy):
|
||||
T5Stack,
|
||||
)
|
||||
|
||||
return {
|
||||
base_policy = {
|
||||
T5Stack:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
@@ -41,6 +50,10 @@ class T5ModelPolicy(Policy):
|
||||
SubModuleReplacementDescription(
|
||||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
T5LayerSelfAttention:
|
||||
@@ -158,30 +171,86 @@ class T5ModelPolicy(Policy):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||
|
||||
for k, v in binding_map:
|
||||
mod = getattr_(self.model, k)
|
||||
setattr_(self.model, v, mod)
|
||||
return self.model
|
||||
|
||||
|
||||
class T5ForConditionalGenerationPolicy(T5ModelPolicy):
|
||||
class T5ModelPolicy(T5BasePolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5Model
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
|
||||
class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5ForConditionalGeneration
|
||||
|
||||
policy = super().module_policy()
|
||||
|
||||
new_item = {
|
||||
T5ForConditionalGeneration:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
|
||||
policy.update(new_item)
|
||||
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
super().postprocess()
|
||||
|
||||
class T5EncoderPolicy(T5ModelPolicy):
|
||||
pass
|
||||
binding_map = {"shared": "lm_head"}
|
||||
|
||||
for k, v in binding_map.items():
|
||||
src_mod = getattr_(self.model, k)
|
||||
dst_mod = getattr_(self.model, v)
|
||||
dst_mod.weight = src_mod.weight
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
class T5EncoderPolicy(T5BasePolicy):
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
base_policy = super().module_policy()
|
||||
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
binding_map = [
|
||||
["shared", "encoder.embed_tokens"],
|
||||
]
|
||||
|
||||
for k, v in binding_map:
|
||||
mod = getattr_(self.model, k)
|
||||
setattr_(self.model, v, mod)
|
||||
return self.model
|
||||
|
@@ -38,17 +38,6 @@ class ModelSharder(object):
|
||||
self._replace_module()
|
||||
self._postprocess()
|
||||
|
||||
def reshape_embedding(self) -> None:
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
vocab_size = self.model_config.vocab_size
|
||||
world_size = self.shard_config.world_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.model_config = self.model.config
|
||||
|
||||
def _preprocess(self) -> None:
|
||||
self.model = self.policy.preprocess()
|
||||
|
||||
|
Reference in New Issue
Block a user