[example] Palm adding gemini, still has bugs (#2221)

This commit is contained in:
ZijianYY
2022-12-29 14:01:09 +08:00
committed by GitHub
parent 7010e18134
commit 63cc77173b
4 changed files with 82 additions and 8 deletions

View File

@@ -47,7 +47,9 @@ class RotaryEmbedding(nn.Module):
def forward(self, max_seq_len, *, device):
seq = torch.arange(max_seq_len, device=device)
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
#freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
return torch.cat((freqs, freqs), dim=-1)