[shardformer] added embedding gradient check (#4124)

Frank Lee
2023-06-30 16:16:44 +08:00
parent 44a190e6ac
commit ae035d305d
14 changed files with 255 additions and 74 deletions
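The gradient check named in the commit title compares the word-embedding gradients of a sharded Bloom model against the unsharded original after an identical backward pass. A minimal sketch of what such a check can look like, assuming hypothetical helper names (org_model, sharded_model, input_ids), a causal-LM output exposing .logits, and that gradient shards are gathered over the tensor-parallel group (not the repository's actual test code):

import torch
import torch.distributed as dist

def check_embedding_grad(org_model, sharded_model, input_ids, atol=1e-5, rtol=1e-3):
    # run the same forward/backward pass on both models (toy loss: sum of logits)
    org_model(input_ids).logits.sum().backward()
    sharded_model(input_ids).logits.sum().backward()

    org_grad = org_model.transformer.word_embeddings.weight.grad
    shard_grad = sharded_model.transformer.word_embeddings.weight.grad

    # each tensor-parallel rank only holds a slice of the vocabulary,
    # so gather the gradient shards along the vocab dimension before comparing
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(dist.get_world_size())]
    dist.all_gather(shard_grad_list, shard_grad)
    full_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_grad, full_grad, atol=atol, rtol=rtol)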


@@ -1,8 +1,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -73,7 +75,6 @@ class BloomPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
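The hunk ends at the divisibility check. A plausible completion, given the surrounding code, is to pad the vocabulary up to the next multiple of the world size (a sketch only; resize_token_embeddings is the standard Hugging Face method, but the exact lines that follow in the file are not shown here):

if vocab_size % world_size != 0:
    # pad the vocabulary so every tensor-parallel rank gets an equal slice
    new_vocab_size = vocab_size + world_size - vocab_size % world_size
    self.model.resize_token_embeddings(new_vocab_size)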
@@ -161,13 +162,12 @@ class BloomPolicy(Policy):
    def new_model_class(self):
        # do nothing
        return self.model
        return None

    def postprocess(self):
        return self.model
# BloomModel
class BloomModelPolicy(BloomPolicy):
    pass
@@ -191,6 +191,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
        policy.update(new_item)
        return policy

    def postprocess(self):
        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            if not isinstance(param, nn.Parameter):
                param = nn.Parameter(param)
            # tie weights
            setattr_(self.model, k, param)
            setattr_(self.model, v, param)
        return self.model
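Here postprocess re-ties lm_head.weight to the word-embedding parameter, so the output projection shares the (possibly resized and sharded) embedding. The getattr_ and setattr_ helpers are imported from .._utils and resolve dotted attribute paths; a minimal sketch of such helpers (an assumption for illustration, not the repository's actual implementation):

import functools

def getattr_(obj, name):
    # resolve a dotted path such as "transformer.word_embeddings.weight"
    return functools.reduce(getattr, name.split("."), obj)

def setattr_(obj, name, value):
    # walk to the parent object, then set the leaf attribute
    *path, leaf = name.split(".")
    setattr(functools.reduce(getattr, path, obj), leaf, value)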
class BloomForSequenceClassificationPolicy(BloomPolicy):