[shardformer] support whisper (#4212)
* support whisper
* fix bug in vocab embedding
* support downstream model of whisper
* update readme
@@ -202,7 +202,6 @@ class VocabParallelEmbedding1D(ParallelModule):
         super().__init__()
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
-        self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.process_group = process_group
@@ -276,6 +275,15 @@ class VocabParallelEmbedding1D(ParallelModule):
             with torch.no_grad():
                 self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
 
+    def _select_padding_idx(self, padding_idx: int):
+        # select padding index according to the rank
+        if padding_idx is None:
+            return None
+        elif padding_idx < self.vocab_end_index and padding_idx >= self.vocab_start_index:
+            return padding_idx - self.vocab_start_index
+        else:
+            return None
+
     def forward(self, input_: Tensor) -> Tensor:
         # Build the mask.
         input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
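For context on what the new _select_padding_idx and the forward-pass mask accomplish, below is a minimal sketch of the vocab-parallel lookup pattern the layer follows. It is not part of the commit: the helper name local_lookup and its parameters are illustrative assumptions, and only the masking predicate and the padding-index remapping are taken from the diff above.

# Sketch only: the vocabulary is split across tensor-parallel ranks, each rank
# holding one weight shard. `local_lookup` and its arguments are hypothetical,
# not the layer's actual API.
import torch
import torch.nn.functional as F

def local_lookup(weight_shard, input_ids, vocab_start_index, vocab_end_index, local_padding_idx=None):
    # Same predicate as forward() in the diff: tokens owned by other ranks.
    input_mask = (input_ids < vocab_start_index) | (input_ids >= vocab_end_index)
    # Shift global ids into this rank's local range; masked positions get a dummy 0.
    masked_input = input_ids.clone() - vocab_start_index
    masked_input[input_mask] = 0
    # `local_padding_idx` is what _select_padding_idx returns: the rank-local
    # index on the rank that owns the global padding token, None elsewhere.
    output = F.embedding(masked_input, weight_shard, padding_idx=local_padding_idx)
    # Zero out rows for tokens this rank does not own; after an all-reduce over
    # the tensor-parallel group, exactly one rank has contributed each row.
    output[input_mask, :] = 0.0
    return output

The vocab-embedding bug fix mentioned in the commit message appears to be exactly this remapping: keeping the global padding_idx unmodified would point at the wrong (or out-of-range) row on every rank except the one that owns it, so _select_padding_idx converts it to a rank-local index or None.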