mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-01 23:30:10 +00:00
[example] Palm adding gemini, still has bugs (#2221)
This commit is contained in:
@@ -47,7 +47,9 @@ class RotaryEmbedding(nn.Module):
|
||||
def forward(self, max_seq_len, *, device):
|
||||
seq = torch.arange(max_seq_len, device=device)
|
||||
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
||||
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||
#freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||
i, j = len(seq.type_as(self.inv_freq)), len(self.inv_freq)
|
||||
freqs = matmul(seq.type_as(self.inv_freq).reshape(i, 1), self.inv_freq.reshape(1, j))
|
||||
return torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user