Fixup bert python bindings.

This commit is contained in:
Adam Treat
2023-07-13 17:57:48 -04:00
committed by AT
parent 6200900677
commit ee4186d579
5 changed files with 37 additions and 23 deletions

View File

@@ -14,6 +14,7 @@
#include <regex>
#include <thread>
#include <algorithm>
#include <numeric>
//#define DEBUG_BERT
@@ -462,11 +463,6 @@ void bert_eval(
ggml_set_f32(sum, 1.0f / N);
inpL = ggml_mul_mat(ctx0, inpL, sum);
// normalizer
ggml_tensor *length = ggml_sqrt(ctx0,
ggml_sum(ctx0, ggml_sqr(ctx0, inpL)));
inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
ggml_tensor *output = inpL;
// run the computation
ggml_build_forward_expand(&gf, output);
@@ -987,6 +983,9 @@ std::vector<float> Bert::embedding(const std::string &text)
}
std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), [embeddingsSumTotal](float num){ return num / embeddingsSumTotal; });
double magnitude = std::sqrt(std::inner_product(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), 0.0));
for (auto &value : embeddingsSum)
value /= magnitude;
std::vector<float> finalEmbeddings(embeddingsSum.begin(), embeddingsSum.end());
return finalEmbeddings;
}