Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 12:43:02 +00:00)
[shardformer] added embedding gradient check (#4124)
@@ -1,8 +1,10 @@
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
 
@@ -73,7 +75,6 @@ class BloomPolicy(Policy):
         r"""
         Reshape the Embedding layer to make the embedding dimension divisible by world_size
         """
-        # TODO:
         vocab_size = self.model.config.vocab_size
         world_size = self.shard_config.tensor_parallel_size
         if vocab_size % world_size != 0:
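The hunk cuts off at the divisibility test, but the rounding it sets up is the standard vocab-padding arithmetic. A minimal sketch of what the branch implies (the helper name and the assert are illustrative, not code from this commit):

def pad_vocab_size(vocab_size: int, world_size: int) -> int:
    # Round vocab_size up to the next multiple of world_size so the
    # embedding table can be split evenly across tensor-parallel ranks.
    if vocab_size % world_size != 0:
        vocab_size = vocab_size + world_size - vocab_size % world_size
    return vocab_size

# BLOOM's vocabulary of 250880 split over 3 ranks pads to 250881
assert pad_vocab_size(250880, 3) == 250881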
@@ -161,13 +162,12 @@ class BloomPolicy(Policy):
 
     def new_model_class(self):
         # do nothing
-        return self.model
+        return None
 
     def postprocess(self):
         return self.model
 
 
-# BertModel
 class BloomModelPolicy(BloomPolicy):
     pass
 
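Returning None instead of the model itself turns new_model_class into a pure signal: None means the policy requests no class substitution. A minimal sketch of that contract (the caller below is an assumption about how a sharder might consume it, not code from this repo):

class NoopPolicy:
    """Stand-in for a policy that keeps the original model class."""

    def new_model_class(self):
        # None means: do not replace the model's class
        return None


def apply_class_replacement(model, policy):
    new_cls = policy.new_model_class()
    if new_cls is not None:
        model.__class__ = new_cls  # only swap when a class is provided
    return model


class DummyModel:
    pass


model = apply_class_replacement(DummyModel(), NoopPolicy())
assert type(model) is DummyModel  # class kept, since the policy returned None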
@@ -191,6 +191,19 @@ class BloomForCausalLMPolicy(BloomPolicy):
             policy.update(new_item)
         return policy
 
+    def postprocess(self):
+        binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+
+            if not isinstance(param, nn.Parameter):
+                param = nn.Parameter(param)
+
+            # tie weights
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
 
 class BloomForSequenceClassificationPolicy(BloomPolicy):
 
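The new postprocess ties the input embedding to the output projection by pointing both attribute paths at one nn.Parameter (getattr_ and setattr_ are the repo's dotted-path helpers imported above). The "embedding gradient check" in the commit title verifies that gradients reach that shared tensor; a toy reproduction of the tie and the check, independent of BLOOM (the layer sizes and names are illustrative, not the test from this PR):

import torch
import torch.nn as nn

# toy stand-in for BLOOM: word embeddings and lm_head share one parameter
embed = nn.Embedding(8, 4)
lm_head = nn.Linear(4, 8, bias=False)
lm_head.weight = embed.weight  # the tie, as postprocess does via setattr_

tokens = torch.randint(0, 8, (2, 3))
logits = lm_head(embed(tokens))
logits.sum().backward()

# tied weights must be the same object and accumulate a single gradient
assert lm_head.weight is embed.weight
assert embed.weight.grad is not None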