Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-09-07 19:40:21 +00:00)
python: embedding cancel callback for nomic client dynamic mode (#2214)
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
@@ -158,7 +158,7 @@ static int32_t get_arch_key_u32(std::string const &modelPath, std::string const

 struct LLamaPrivate {
     const std::string modelPath;
-    bool modelLoaded;
+    bool modelLoaded = false;
     int device = -1;
     llama_model *model = nullptr;
     llama_context *ctx = nullptr;

@@ -166,12 +166,11 @@ struct LLamaPrivate {
     llama_context_params ctx_params;
     int64_t n_threads = 0;
     std::vector<LLModel::Token> end_tokens;
+    const char *backend_name = nullptr;
 };

 LLamaModel::LLamaModel()
-    : d_ptr(new LLamaPrivate) {
-    d_ptr->modelLoaded = false;
-}
+    : d_ptr(new LLamaPrivate) {}

 // default hparams (LLaMA 7B)
 struct llama_file_hparams {

@@ -291,6 +290,8 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     d_ptr->model_params.progress_callback = &LLModel::staticProgressCallback;
     d_ptr->model_params.progress_callback_user_data = this;

+    d_ptr->backend_name = "cpu"; // default
+
 #ifdef GGML_USE_KOMPUTE
     if (d_ptr->device != -1) {
         d_ptr->model_params.main_gpu = d_ptr->device;

@@ -301,6 +302,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)

     if (llama_verbose()) {
         std::cerr << "llama.cpp: using Metal" << std::endl;
+        d_ptr->backend_name = "metal";
     }

     // always fully offload on Metal

@@ -364,6 +366,7 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
 #ifdef GGML_USE_KOMPUTE
     if (usingGPUDevice() && ggml_vk_has_device()) {
         std::cerr << "llama.cpp: using Vulkan on " << ggml_vk_current_device().name << std::endl;
+        d_ptr->backend_name = "kompute";
     }
 #endif

@@ -674,7 +677,7 @@ void LLamaModel::embed(

 void LLamaModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb
 ) {
     if (!d_ptr->model)
         throw std::logic_error("no model is loaded");

@@ -712,7 +715,7 @@ void LLamaModel::embed(
         throw std::invalid_argument(ss.str());
     }

-    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, spec);
+    embedInternal(texts, embeddings, *prefix, dimensionality, tokenCount, doMean, atlas, cancelCb, spec);
 }

 // MD5 hash of "nomic empty"

@@ -730,7 +733,7 @@ double getL2NormScale(T *start, T *end) {

 void LLamaModel::embedInternal(
     const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec
+    size_t *tokenCount, bool doMean, bool atlas, LLModel::EmbedCancelCallback *cancelCb, const EmbModelSpec *spec
 ) {
     typedef std::vector<LLModel::Token> TokenString;
     static constexpr int32_t atlasMaxLength = 8192;

@@ -822,6 +825,23 @@ void LLamaModel::embedInternal(
     }
     inputs.clear();

+    if (cancelCb) {
+        // copy of batching code below, but just count tokens instead of running inference
+        unsigned nBatchTokens = 0;
+        std::vector<unsigned> batchSizes;
+        for (const auto &inp: batches) {
+            if (nBatchTokens + inp.batch.size() > n_batch) {
+                batchSizes.push_back(nBatchTokens);
+                nBatchTokens = 0;
+            }
+            nBatchTokens += inp.batch.size();
+        }
+        batchSizes.push_back(nBatchTokens);
+        if (cancelCb(batchSizes.data(), batchSizes.size(), d_ptr->backend_name)) {
+            throw std::runtime_error("operation was canceled");
+        }
+    }
+
     // initialize batch
     struct llama_batch batch = llama_batch_init(n_batch, 0, 1);

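For reference, here is a small self-contained sketch (not part of the diff) of what this pre-count pass computes: it groups the per-chunk token counts into batches of at most n_batch tokens and records each batch's size, which is exactly the array handed to the cancel callback before any batch is evaluated. The name countBatchSizes and the sample numbers are illustrative only.

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative re-creation of the counting pass: group per-chunk token counts
// into batches of at most n_batch tokens and record each batch's total size.
static std::vector<unsigned> countBatchSizes(const std::vector<std::size_t> &tokensPerChunk, std::size_t n_batch)
{
    std::vector<unsigned> batchSizes;
    unsigned nBatchTokens = 0;
    for (std::size_t nTokens : tokensPerChunk) {
        if (nBatchTokens + nTokens > n_batch) {
            batchSizes.push_back(nBatchTokens); // current batch is full, start a new one
            nBatchTokens = 0;
        }
        nBatchTokens += static_cast<unsigned>(nTokens);
    }
    batchSizes.push_back(nBatchTokens); // final, possibly partial batch
    return batchSizes;
}

int main()
{
    // e.g. four chunks of 300, 400, 250 and 100 tokens with n_batch = 512
    for (unsigned n : countBatchSizes({300, 400, 250, 100}, 512))
        std::cout << n << ' '; // prints: 300 400 350
    std::cout << '\n';
}

The callback therefore sees the full cost of the request up front (per-batch token counts plus the backend name) and can abort the whole embedding before any inference runs.
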
@@ -871,7 +891,7 @@ void LLamaModel::embedInternal(
     };

     // break into batches
-    for (auto &inp: batches) {
+    for (const auto &inp: batches) {
         // encode if at capacity
         if (batch.n_tokens + inp.batch.size() > n_batch) {
             decode();

@@ -39,7 +39,8 @@ public:
     size_t embeddingSize() const override;
     // user-specified prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;
+               int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+               EmbedCancelCallback *cancelCb = nullptr) override;
     // automatic prefix
     void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval, int dimensionality = -1,
                size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false) override;

@@ -61,7 +62,8 @@ protected:
     int32_t layerCount(std::string const &modelPath) const override;

     void embedInternal(const std::vector<std::string> &texts, float *embeddings, std::string prefix, int dimensionality,
-                       size_t *tokenCount, bool doMean, bool atlas, const EmbModelSpec *spec);
+                       size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb,
+                       const EmbModelSpec *spec);
 };

 #endif // LLAMAMODEL_H

@@ -105,12 +105,15 @@ public:
                         bool special = false,
                         std::string *fakeReply = nullptr);

+    using EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend);
+
     virtual size_t embeddingSize() const {
         throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
     }
     // user-specified prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix,
-                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);
+                       int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false,
+                       EmbedCancelCallback *cancelCb = nullptr);
     // automatic prefix
     virtual void embed(const std::vector<std::string> &texts, float *embeddings, bool isRetrieval,
                        int dimensionality = -1, size_t *tokenCount = nullptr, bool doMean = true, bool atlas = false);

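To illustrate the new hook, here is a hedged sketch of a callback matching the EmbedCancelCallback signature declared above. The token budget and the decision to cancel only on the "cpu" backend are made-up policy for the example, not something this change prescribes.

#include <cstring>

// Matches LLModel::EmbedCancelCallback = bool(unsigned *batchSizes, unsigned nBatch, const char *backend).
// Returning true cancels the embedding before any batch is evaluated.
static bool cancelIfTooExpensive(unsigned *batchSizes, unsigned nBatch, const char *backend)
{
    unsigned long long totalTokens = 0;
    for (unsigned i = 0; i < nBatch; i++)
        totalTokens += batchSizes[i];

    const unsigned long long kCpuTokenBudget = 16384; // hypothetical threshold for this sketch
    return std::strcmp(backend, "cpu") == 0 && totalTokens > kCpuTokenBudget;
}

A pointer to such a function is passed as the new cancelCb argument of LLModel::embed; the default of nullptr keeps the previous behavior, since the counting pass only runs when a callback is supplied.
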
@@ -159,7 +159,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,

 float *llmodel_embed(
     llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix, int dimensionality,
-    size_t *token_count, bool do_mean, bool atlas, const char **error
+    size_t *token_count, bool do_mean, bool atlas, llmodel_emb_cancel_callback cancel_cb, const char **error
 ) {
     auto *wrapper = static_cast<LLModelWrapper *>(model);

@@ -185,7 +185,7 @@ float *llmodel_embed(
         if (prefix) { prefixStr = prefix; }

         embedding = new float[embd_size];
-        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas);
+        wrapper->llModel->embed(textsVec, embedding, prefixStr, dimensionality, token_count, do_mean, atlas, cancel_cb);
     } catch (std::exception const &e) {
         llmodel_set_error(error, e.what());
         return nullptr;

@@ -82,6 +82,15 @@ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response
  */
 typedef bool (*llmodel_recalculate_callback)(bool is_recalculating);

+/**
+ * Embedding cancellation callback for use with llmodel_embed.
+ * @param batch_sizes The number of tokens in each batch that will be embedded.
+ * @param n_batch The number of batches that will be embedded.
+ * @param backend The backend that will be used for embedding. One of "cpu", "kompute", or "metal".
+ * @return True to cancel llmodel_embed, false to continue.
+ */
+typedef bool (*llmodel_emb_cancel_callback)(unsigned *batch_sizes, unsigned n_batch, const char *backend);
+
 /**
  * Create a llmodel instance.
  * Recognises correct model type from file at model_path

@@ -198,12 +207,14 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
  *                truncate.
  * @param atlas Try to be fully compatible with the Atlas API. Currently, this means texts longer than 8192 tokens with
  *              long_text_mode="mean" will raise an error. Disabled by default.
+ * @param cancel_cb Cancellation callback, or NULL. See the documentation of llmodel_emb_cancel_callback.
  * @param error Return location for a malloc()ed string that will be set on error, or NULL.
  * @return A pointer to an array of floating point values passed to the calling method which then will
  *         be responsible for lifetime of this memory. NULL if an error occurred.
  */
 float *llmodel_embed(llmodel_model model, const char **texts, size_t *embedding_size, const char *prefix,
-                     int dimensionality, size_t *token_count, bool do_mean, bool atlas, const char **error);
+                     int dimensionality, size_t *token_count, bool do_mean, bool atlas,
+                     llmodel_emb_cancel_callback cancel_cb, const char **error);

 /**
  * Frees the memory allocated by the llmodel_embedding function.

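On the C API side, a minimal sketch of wiring a cancel callback into llmodel_embed. It assumes an already loaded llmodel_model and a NULL-terminated texts array; the logging-only callback body and the helper name embedWithCancel are illustrative, not part of the change.

#include <cstdio>
#include "llmodel_c.h"

// Conforms to llmodel_emb_cancel_callback; returning true makes llmodel_embed
// report an error instead of running inference.
static bool logAndContinue(unsigned *batch_sizes, unsigned n_batch, const char *backend)
{
    unsigned total = 0;
    for (unsigned i = 0; i < n_batch; i++)
        total += batch_sizes[i];
    std::fprintf(stderr, "embedding %u tokens in %u batches on %s\n", total, n_batch, backend);
    return false; // a real client might instead return a flag set from another thread
}

// `model` must already be loaded; `texts` is assumed to be NULL-terminated.
static float *embedWithCancel(llmodel_model model, const char **texts)
{
    const char *error = nullptr;
    size_t embeddingSize = 0, tokenCount = 0;
    float *emb = llmodel_embed(model, texts, &embeddingSize, /*prefix*/ nullptr, /*dimensionality*/ -1,
                               &tokenCount, /*do_mean*/ true, /*atlas*/ false, logAndContinue, &error);
    if (!emb)
        std::fprintf(stderr, "embedding failed: %s\n", error ? error : "(unknown)");
    return emb; // caller owns the returned buffer and frees it via the API's free function
}
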
@@ -270,7 +270,7 @@ void LLModel::generateResponse(std::function<bool(int32_t, const std::string&)>

 void LLModel::embed(
     const std::vector<std::string> &texts, float *embeddings, std::optional<std::string> prefix, int dimensionality,
-    size_t *tokenCount, bool doMean, bool atlas
+    size_t *tokenCount, bool doMean, bool atlas, EmbedCancelCallback *cancelCb
 ) {
     (void)texts;
     (void)embeddings;

@@ -279,6 +279,7 @@ void LLModel::embed(
     (void)tokenCount;
     (void)doMean;
     (void)atlas;
+    (void)cancelCb;
     throw std::logic_error(std::string(implementation().modelType()) + " does not support embeddings");
 }
