[shardformer] support whisper (#4212)

* support whisper

* fix bug in vocabembedding

* support downstream model of whisper

* update readme
Author: FoolPlayer
Date: 2023-07-17 14:25:32 +08:00
Committed by: Hongxin Liu
Parent: dd2bf02679
Commit: 9ee4ebea83
7 changed files with 443 additions and 2 deletions
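
The Whisper support added here plugs into the ShardFormer entry point alongside the existing model policies. Below is a minimal usage sketch, not taken from this commit: the `ShardConfig`/`ShardFormer` names, the `enable_tensor_parallelism` flag, and the return value of `optimize()` are assumptions based on later ColossalAI releases and may differ at this revision.

# Hedged sketch: shard a Hugging Face Whisper model with ShardFormer.
# Assumes torch.distributed / colossalai has already been launched so a
# process group exists; API names are assumptions, not part of this diff.
from transformers import WhisperForConditionalGeneration
from colossalai.shardformer import ShardConfig, ShardFormer

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

shard_config = ShardConfig(enable_tensor_parallelism=True)
shard_former = ShardFormer(shard_config=shard_config)

# optimize() swaps supported submodules (attention, MLP, and embeddings such
# as the VocabParallelEmbedding1D patched below) for tensor-parallel
# counterparts according to the registered Whisper policy.
sharded_model, shared_params = shard_former.optimize(model)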

@@ -202,7 +202,6 @@ class VocabParallelEmbedding1D(ParallelModule):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
@@ -276,6 +275,15 @@ class VocabParallelEmbedding1D(ParallelModule):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
+    def _select_padding_idx(self, padding_idx: int):
+        # select padding index according to the rank
+        if padding_idx is None:
+            return None
+        elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
+            return padding_idx - self.vocab_start_index
+        else:
+            return None
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
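
To make the new _select_padding_idx and the mask built in forward() concrete, here is a small self-contained sketch (plain PyTorch; vocab_size, world_size, and the sample token ids are illustrative, not from the commit). Each rank owns one contiguous slice of the vocabulary, keeps padding_idx only if it falls inside that slice (shifted to a local index), and maps out-of-range token ids to 0 so the local lookup stays valid; in the full layer those masked rows are zeroed and recombined across ranks.

import torch

vocab_size, world_size, padding_idx = 16, 4, 0
per_rank = vocab_size // world_size
input_ = torch.tensor([0, 3, 7, 12])   # sample global token ids

for rank in range(world_size):
    vocab_start_index = rank * per_rank
    vocab_end_index = vocab_start_index + per_rank

    # Mirrors _select_padding_idx: only the rank whose slice contains the
    # global padding_idx keeps it, shifted into local coordinates.
    if padding_idx is None:
        local_padding_idx = None
    elif vocab_start_index <= padding_idx < vocab_end_index:
        local_padding_idx = padding_idx - vocab_start_index
    else:
        local_padding_idx = None

    # Mirrors the mask in forward(): ids outside this rank's slice are
    # flagged, shifted ids are clamped to 0 so the local embedding lookup is
    # in range, and their outputs would be zeroed before the cross-rank reduce.
    input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)
    masked_input = input_ - vocab_start_index
    masked_input[input_mask] = 0

    print(rank, local_padding_idx, masked_input.tolist())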