mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-09-07 19:40:21 +00:00
implement local Nomic Embed via llama.cpp (#2086)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <initializer_list>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
@@ -30,6 +31,19 @@ static constexpr int GGUF_VER_MAX = 3;
|
||||
|
||||
static const char * const modelType_ = "LLaMA";
|
||||
|
||||
static const std::vector<const char *> KNOWN_ARCHES {
|
||||
"baichuan", "bert", "bloom", "codeshell", "falcon", "gemma", "gpt2", "llama", "mpt", "nomic-bert", "orion",
|
||||
"persimmon", "phi2", "plamo", "qwen", "qwen2", "refact", "stablelm", "starcoder"
|
||||
};
|
||||
|
||||
static const std::vector<const char *> EMBEDDING_ARCHES {
|
||||
"bert", "nomic-bert"
|
||||
};
|
||||
|
||||
static bool is_embedding_arch(const std::string &arch) {
|
||||
return std::find(EMBEDDING_ARCHES.begin(), EMBEDDING_ARCHES.end(), arch) < EMBEDDING_ARCHES.end();
|
||||
}
|
||||
|
||||
static bool llama_verbose() {
|
||||
const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");
|
||||
return var && *var;
|
||||
@@ -124,7 +138,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const
|
||||
auto * ctx = load_gguf(modelPath.c_str());
|
||||
if (!ctx)
|
||||
return -1;
|
||||
auto arch = get_arch_name(ctx);
|
||||
std::string arch = get_arch_name(ctx);
|
||||
|
||||
int32_t value = -1;
|
||||
if (ctx) {
|
||||
@@ -193,7 +207,7 @@ size_t LLamaModel::requiredMem(const std::string &modelPath, int n_ctx, int ngl)
|
||||
return filesize + est_kvcache_size;
|
||||
}
|
||||
|
||||
bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
|
||||
bool LLamaModel::isModelBlacklisted(const std::string &modelPath) const {
|
||||
auto * ctx = load_gguf(modelPath.c_str());
|
||||
if (!ctx) {
|
||||
std::cerr << __func__ << ": failed to load " << modelPath << "\n";
|
||||
@@ -229,6 +243,18 @@ bool LLamaModel::isModelBlacklisted(const std::string &modelPath) {
|
||||
return res;
|
||||
}
|
||||
|
||||
bool LLamaModel::isEmbeddingModel(const std::string &modelPath) const {
|
||||
auto *ctx_gguf = load_gguf(modelPath.c_str());
|
||||
if (!ctx_gguf) {
|
||||
std::cerr << __func__ << ": failed to load GGUF from " << modelPath << "\n";
|
||||
return false;
|
||||
}
|
||||
|
||||
std::string arch = get_arch_name(ctx_gguf);
|
||||
gguf_free(ctx_gguf);
|
||||
return is_embedding_arch(arch);
|
||||
}
|
||||
|
||||
bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
|
||||
{
|
||||
d_ptr->modelLoaded = false;
|
||||
@@ -287,20 +313,25 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
|
||||
if (!d_ptr->model) {
|
||||
fflush(stdout);
|
||||
d_ptr->device = -1;
|
||||
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
|
||||
std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
|
||||
if (n_ctx > n_ctx_train) {
|
||||
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
|
||||
<< n_ctx << " specified)\n";
|
||||
}
|
||||
|
||||
// -- initialize the context --
|
||||
|
||||
d_ptr->ctx_params = llama_context_default_params();
|
||||
|
||||
bool isEmbedding = is_embedding_arch(llama_model_arch(d_ptr->model));
|
||||
const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
|
||||
if (isEmbedding) {
|
||||
d_ptr->ctx_params.n_batch = n_ctx_train;
|
||||
} else {
|
||||
if (n_ctx > n_ctx_train) {
|
||||
std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
|
||||
<< n_ctx << " specified)\n";
|
||||
}
|
||||
}
|
||||
|
||||
d_ptr->ctx_params.n_ctx = n_ctx;
|
||||
d_ptr->ctx_params.seed = params.seed;
|
||||
d_ptr->ctx_params.type_k = params.kv_type;
|
||||
@@ -314,6 +345,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
|
||||
d_ptr->ctx_params.n_threads = d_ptr->n_threads;
|
||||
d_ptr->ctx_params.n_threads_batch = d_ptr->n_threads;
|
||||
|
||||
if (m_supportsEmbedding)
|
||||
d_ptr->ctx_params.embeddings = true;
|
||||
|
||||
d_ptr->ctx = llama_new_context_with_model(d_ptr->model, d_ptr->ctx_params);
|
||||
if (!d_ptr->ctx) {
|
||||
fflush(stdout);
|
||||
@@ -332,6 +366,9 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
|
||||
}
|
||||
#endif
|
||||
|
||||
m_supportsEmbedding = isEmbedding;
|
||||
m_supportsCompletion = !isEmbedding;
|
||||
|
||||
fflush(stdout);
|
||||
d_ptr->modelLoaded = true;
|
||||
return true;
|
||||
@@ -535,6 +572,320 @@ bool LLamaModel::usingGPUDevice()
|
||||
#endif
|
||||
}
|
||||
|
||||
void llama_batch_add(
|
||||
struct llama_batch & batch,
|
||||
llama_token id,
|
||||
llama_pos pos,
|
||||
const std::vector<llama_seq_id> & seq_ids,
|
||||
bool logits) {
|
||||
batch.token [batch.n_tokens] = id;
|
||||
batch.pos [batch.n_tokens] = pos;
|
||||
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
|
||||
for (size_t i = 0; i < seq_ids.size(); ++i) {
|
||||
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
|
||||
}
|
||||
batch.logits [batch.n_tokens] = logits;
|
||||
|
||||
batch.n_tokens++;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch &batch, const std::vector<LLModel::Token> &tokens, int seq_id) {
|
||||
for (unsigned i = 0; i < tokens.size(); i++) {
|
||||
llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
size_t LLamaModel::embeddingSize() const {
|
||||
return llama_n_embd(d_ptr->model);
|
||||
}
|
||||
|
||||
struct EmbModelSpec {
|
||||
const char *docPrefix;
|
||||
const char *queryPrefix;
|
||||
std::vector<const char *> otherPrefixes = {};
|
||||
bool matryoshkaCapable = false;
|
||||
const char *recommendedDims = nullptr;
|
||||
};
|
||||
|
||||
struct EmbModelGroup {
|
||||
EmbModelSpec spec;
|
||||
std::vector<const char *> names;
|
||||
};
|
||||
|
||||
static const EmbModelSpec NOPREFIX_SPEC {nullptr, nullptr};
|
||||
static const EmbModelSpec NOMIC_SPEC {"search_document", "search_query", {"clustering", "classification"}};
|
||||
static const EmbModelSpec E5_SPEC {"passage", "query"};
|
||||
|
||||
static const EmbModelSpec NOMIC_1_5_SPEC {
|
||||
"search_document", "search_query", {"clustering", "classification"}, true, "[768, 512, 384, 256, 128]"
|
||||
};
|
||||
static const EmbModelSpec LLM_EMBEDDER_SPEC {
|
||||
"Represent this document for retrieval",
|
||||
"Represent this query for retrieving relevant documents",
|
||||
};
|
||||
static const EmbModelSpec BGE_SPEC {
|
||||
nullptr, "Represent this sentence for searching relevant passages",
|
||||
};
|
||||
static const EmbModelSpec E5_MISTRAL_SPEC {
|
||||
nullptr, "Instruct: Given a query, retrieve relevant passages that answer the query\nQuery",
|
||||
};
|
||||
|
||||
static const EmbModelGroup EMBEDDING_MODEL_SPECS[] {
|
||||
{NOPREFIX_SPEC, {"all-MiniLM-L6-v1", "all-MiniLM-L12-v1", "all-MiniLM-L6-v2", "all-MiniLM-L12-v2"}},
|
||||
{NOMIC_SPEC, {"nomic-embed-text-v1", "nomic-embed-text-v1-ablated", "nomic-embed-text-v1-unsupervised"}},
|
||||
{NOMIC_1_5_SPEC, {"nomic-embed-text-v1.5"}},
|
||||
{LLM_EMBEDDER_SPEC, {"llm-embedder"}},
|
||||
{BGE_SPEC, {"bge-small-en", "bge-base-en", "bge-large-en",
|
||||
"bge-small-en-v1.5", "bge-base-en-v1.5", "bge-large-en-v1.5"}},
|
||||
{E5_SPEC, {"e5-small", "e5-base", "e5-large",
|
||||
"e5-small-unsupervised", "e5-base-unsupervised", "e5-large-unsupervised",
|
||||
"e5-small-v2", "e5-base-v2", "e5-large-v2"}},
|
||||
{E5_MISTRAL_SPEC, {"e5-mistral-7b-instruct",
|
||||
"multilingual-e5-small", "multilingual-e5-base", "multilingual-e5-large",
|
||||
"multilingual-e5-large-instruct"}},
|
||||
};
|
||||
|
||||
static const EmbModelSpec *getEmbedSpec(const std::string &modelName) {
|
||||
static const auto &specs = EMBEDDING_MODEL_SPECS;
|
||||
auto it = std::find_if(specs, std::end(specs),
|
||||
[&modelName](auto &spec) {
|
||||
auto &names = spec.names;
|
||||
return std::find(names.begin(), names.end(), modelName) < names.end();
|
||||
}
|
||||
);
|
||||
return it < std::end(specs) ? &it->spec : nullptr;
|
||||
}
|
||||
|
||||
void LLamaModel::embed(
|
||||
const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality, bool doMean,
|
||||
bool atlas
|
||||
) {
|
||||
const EmbModelSpec *spec;
|
||||
std::optional<std::string> prefix;
|
||||
if (d_ptr->model && (spec = getEmbedSpec(llama_model_name(d_ptr->model))))
|
||||
prefix = isRetrieval ? spec->queryPrefix : spec->docPrefix;
|
||||
|
||||
embed(texts, embeddings, prefix, dimensionality, doMean, atlas);
|
||||
}
|
||||
|
||||
void LLamaModel::embed(
|
||||
const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
|
||||
bool doMean, bool atlas
|
||||
) {
|
||||
if (!d_ptr->model)
|
||||
throw std::logic_error("no model is loaded");
|
||||
|
||||
const char *modelName = llama_model_name(d_ptr->model);
|
||||
if (!m_supportsEmbedding)
|
||||
throw std::logic_error("not an embedding model: "s + modelName);
|
||||
|
||||
auto *spec = getEmbedSpec(modelName);
|
||||
if (!spec)
|
||||
std::cerr << __func__ << ": warning: unknown model " << modelName << "\n";
|
||||
|
||||
const int32_t n_embd = llama_n_embd(d_ptr->model);
|
||||
if (dimensionality < 0) {
|
||||
dimensionality = n_embd;
|
||||
} else if (spec && dimensionality != n_embd) {
|
||||
auto msg = [dimensionality, modelName]() {
|
||||
return "unsupported dimensionality " + std::to_string(dimensionality) + " for model " + modelName;
|
||||
};
|
||||
if (!spec->matryoshkaCapable)
|
||||
throw std::logic_error(msg() + " (supported: " + std::to_string(n_embd) + ")");
|
||||
if (dimensionality == 0 || dimensionality > n_embd)
|
||||
throw std::logic_error(msg() + " (recommended: " + spec->recommendedDims + ")");
|
||||
}
|
||||
|
||||
if (!prefix) {
|
||||
if (spec) {
|
||||
prefix = spec->docPrefix;
|
||||
} else {
|
||||
std::cerr << __func__ << ": warning: assuming no prefix\n";
|
||||
prefix = "";
|
||||
}
|
||||
} else if (spec && prefix != spec->docPrefix && prefix != spec->queryPrefix &&
|
||||
std::find(spec->otherPrefixes.begin(), spec->otherPrefixes.end(), *prefix) == spec->otherPrefixes.end())
|
||||
{
|
||||
std::stringstream ss;
|
||||
ss << std::quoted(*prefix) << " is not a valid task type for model " << modelName;
|
||||
throw std::logic_error(ss.str());
|
||||
}
|
||||
|
||||
embedInternal(texts, embeddings, *prefix, dimensionality, doMean, atlas, spec);
|
||||
}
|
||||
|
||||
// MD5 hash of "nomic empty"
|
||||
static const char EMPTY_PLACEHOLDER[] = "24df574ea1c998de59d5be15e769658e";
|
||||
|
||||
auto product(double a) -> std::function<double(double)> {
|
||||
return [a](double b) { return a * b; };
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
double getL2NormScale(T *start, T *end) {
|
||||
double magnitude = std::sqrt(std::inner_product(start, end, start, 0.0));
|
||||
return 1.0 / std::max(magnitude, 1e-12);
|
||||
}
|
||||
|
||||
void LLamaModel::embedInternal(
|
||||
const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
|
||||
bool doMean, bool atlas, const EmbModelSpec *spec
|
||||
) {
|
||||
typedef std::vector<LLModel::Token> TokenString;
|
||||
static constexpr int32_t atlasMaxLength = 8192;
|
||||
static constexpr int chunkOverlap = 8; // Atlas overlaps n_batch-sized chunks of input by 8 tokens
|
||||
|
||||
const llama_token bos_token = llama_token_bos(d_ptr->model);
|
||||
const llama_token eos_token = llama_token_eos(d_ptr->model);
|
||||
|
||||
assert(shouldAddBOS());
|
||||
bool addEOS = llama_vocab_type(d_ptr->model) == LLAMA_VOCAB_TYPE_WPM;
|
||||
|
||||
// no EOS, optional BOS
|
||||
auto tokenize = [this, addEOS](std::string text, TokenString &tokens, bool addBOS) {
|
||||
if (!text.empty() && text[0] != ' ')
|
||||
text = ' ' + text; // normalize for SPM - our fork of llama.cpp doesn't add a space prefix
|
||||
|
||||
tokens.resize(text.length()+4);
|
||||
int32_t n_tokens = llama_tokenize(d_ptr->model, text.c_str(), text.length(), tokens.data(), tokens.size(), addBOS, false);
|
||||
assert(addEOS == (eos_token != -1 && tokens[n_tokens - 1] == eos_token));
|
||||
tokens.resize(n_tokens - addEOS); // erase EOS/SEP
|
||||
};
|
||||
|
||||
// tokenize the texts
|
||||
std::vector<TokenString> inputs;
|
||||
for (unsigned i = 0; i < texts.size(); i++) {
|
||||
auto &text = texts[i];
|
||||
auto &inp = inputs.emplace_back();
|
||||
tokenize(text, inp, false);
|
||||
if (atlas && inp.size() > atlasMaxLength) {
|
||||
if (doMean) {
|
||||
throw std::logic_error(
|
||||
"length of text at index " + std::to_string(i) + " is " + std::to_string(inp.size()) +
|
||||
" tokens which exceeds limit of " + std::to_string(atlasMaxLength)
|
||||
);
|
||||
}
|
||||
inp.resize(atlasMaxLength);
|
||||
} else if (inp.empty()) {
|
||||
if (!atlas || !text.empty()) {
|
||||
std::cerr << __func__ << ": warning: chunking tokenized text at index " << std::to_string(i)
|
||||
<< " into zero tokens\n";
|
||||
}
|
||||
tokenize(EMPTY_PLACEHOLDER, inp, false);
|
||||
}
|
||||
}
|
||||
|
||||
// tokenize the prefix
|
||||
TokenString prefixTokens;
|
||||
if (prefix.empty()) {
|
||||
prefixTokens.push_back(bos_token);
|
||||
} else {
|
||||
tokenize(prefix + ':', prefixTokens, true);
|
||||
}
|
||||
|
||||
const uint32_t n_batch = llama_n_batch(d_ptr->ctx);
|
||||
const uint32_t max_len = n_batch - (prefixTokens.size() + addEOS); // minus BOS/CLS and EOS/SEP
|
||||
if (chunkOverlap >= max_len) {
|
||||
throw std::logic_error("max chunk length of " + std::to_string(max_len) + " is smaller than overlap of " +
|
||||
std::to_string(chunkOverlap) + " tokens");
|
||||
}
|
||||
|
||||
// split into max_len-sized chunks
|
||||
struct split_batch { int idx; TokenString batch; };
|
||||
std::vector<split_batch> batches;
|
||||
for (unsigned i = 0; i < inputs.size(); i++) {
|
||||
auto &input = inputs[i];
|
||||
for (auto it = input.begin(); it < input.end(); it += max_len) {
|
||||
if (it > input.begin()) { it -= chunkOverlap; }
|
||||
auto end = std::min(it + max_len, input.end());
|
||||
auto &batch = batches.emplace_back(i, prefixTokens).batch;
|
||||
batch.insert(batch.end(), it, end);
|
||||
batch.push_back(eos_token);
|
||||
if (!doMean) { break; /* limit text to one chunk */ }
|
||||
}
|
||||
}
|
||||
inputs.clear();
|
||||
|
||||
// initialize batch
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
|
||||
|
||||
// n_texts x n_embd matrix
|
||||
const int32_t n_embd = llama_n_embd(d_ptr->model);
|
||||
std::vector<double> embeddingsSum(texts.size() * n_embd);
|
||||
std::vector<int> embeddingsSumTotal(texts.size());
|
||||
std::vector<int> queued_indices; // text indices of batches to be processed
|
||||
|
||||
auto decode = [this, &queued_indices, n_embd, &batch, &embeddingsSum, &embeddingsSumTotal, spec, dimensionality]() {
|
||||
if (llama_decode(d_ptr->ctx, batch) < 0)
|
||||
throw std::runtime_error("llama_decode failed");
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; ++i) {
|
||||
if (!batch.logits[i]) { continue; }
|
||||
int i_prompt = queued_indices[batch.seq_id[i][0]];
|
||||
auto *out = &embeddingsSum[i_prompt * n_embd];
|
||||
|
||||
// sequence embeddings aren't available when pooling_type is NONE
|
||||
auto *embd = llama_get_embeddings_seq(d_ptr->ctx, batch.seq_id[i][0]);
|
||||
if (!embd) { embd = llama_get_embeddings_ith(d_ptr->ctx, i); }
|
||||
assert(embd);
|
||||
|
||||
auto *embd_end = embd + n_embd;
|
||||
|
||||
// layer normalization for nomic-embed-text-v1.5
|
||||
if (spec && spec->matryoshkaCapable) {
|
||||
// normalize mean
|
||||
double mean = std::accumulate(embd, embd_end, 0.0) / n_embd;
|
||||
std::transform(embd, embd_end, embd, [mean](double f){ return f - mean; });
|
||||
|
||||
// unbiased sample variance, with Bessel's correction
|
||||
double variance = std::inner_product(embd, embd_end, embd, 0.0) / (n_embd - 1);
|
||||
|
||||
// trim to matryoshka dim
|
||||
embd_end = embd + dimensionality;
|
||||
|
||||
// normalize variance
|
||||
std::transform(embd, embd_end, embd, product(1.0 / std::sqrt(variance + 1e-5)));
|
||||
}
|
||||
|
||||
// L2 norm
|
||||
auto scale = getL2NormScale(embd, embd_end);
|
||||
std::transform(embd, embd_end, out, out, [scale](double e, double o){ return o + scale * e; });
|
||||
embeddingsSumTotal[i_prompt]++;
|
||||
}
|
||||
};
|
||||
|
||||
// break into batches
|
||||
for (auto &inp: batches) {
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + inp.batch.size() > n_batch) {
|
||||
decode();
|
||||
batch.n_tokens = 0;
|
||||
queued_indices.clear();
|
||||
}
|
||||
|
||||
// add to batch
|
||||
batch_add_seq(batch, inp.batch, queued_indices.size());
|
||||
queued_indices.push_back(inp.idx);
|
||||
}
|
||||
|
||||
// final batch
|
||||
decode();
|
||||
|
||||
for (unsigned i = 0; i < texts.size(); i++) {
|
||||
auto *embd = &embeddingsSum[i * n_embd];
|
||||
auto *embd_end = embd + dimensionality;
|
||||
int total = embeddingsSumTotal[i];
|
||||
|
||||
// average over chunks
|
||||
std::transform(embd, embd_end, embd, product(1.0 / total));
|
||||
|
||||
// L2 norm and copy
|
||||
auto scale = getL2NormScale(embd, embd_end);
|
||||
std::transform(embd, embd_end, embeddings, product(scale));
|
||||
embeddings += dimensionality;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define DLL_EXPORT __declspec(dllexport)
|
||||
#else
|
||||
@@ -556,23 +907,21 @@ DLL_EXPORT const char *get_build_variant() {
|
||||
|
||||
DLL_EXPORT bool magic_match(const char *fname) {
|
||||
auto * ctx = load_gguf(fname);
|
||||
auto arch = get_arch_name(ctx);
|
||||
std::string arch = get_arch_name(ctx);
|
||||
|
||||
bool valid = true;
|
||||
|
||||
static const std::vector<const char *> known_arches {
|
||||
"baichuan", "bloom", "codeshell", "falcon", "gemma", "gpt2", "llama", "mpt", "orion", "persimmon", "phi2",
|
||||
"plamo", "qwen", "qwen2", "refact", "stablelm", "starcoder"
|
||||
};
|
||||
|
||||
if (std::find(known_arches.begin(), known_arches.end(), arch) == known_arches.end()) {
|
||||
if (std::find(KNOWN_ARCHES.begin(), KNOWN_ARCHES.end(), arch) == KNOWN_ARCHES.end()) {
|
||||
// not supported by this version of llama.cpp
|
||||
if (!(arch == "gptj" || arch == "bert")) { // we support these via other modules
|
||||
if (arch != "gptj") { // we support this via another module
|
||||
std::cerr << __func__ << ": unsupported model architecture: " << arch << "\n";
|
||||
}
|
||||
valid = false;
|
||||
}
|
||||
|
||||
if (valid && is_embedding_arch(arch) && gguf_find_key(ctx, (arch + ".pooling_type").c_str()) < 0)
|
||||
valid = false; // old pre-llama.cpp embedding model, e.g. all-MiniLM-L6-v2-f16.gguf
|
||||
|
||||
gguf_free(ctx);
|
||||
return valid;
|
||||
}
|
||||
|
Reference in New Issue
Block a user