From ba4b28fcd508bef34bb46fc850786db6524a2bf0 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 27 Apr 2023 11:08:15 -0400
Subject: [PATCH] Move the promptCallback to own function.

---
 llm.cpp                | 36 ++++++++++++++++++++++++------------
 llm.h                  |  3 ++-
 llmodel/gptj.cpp       | 19 ++++++++++---------
 llmodel/gptj.h         |  5 +++--
 llmodel/llamamodel.cpp | 18 +++++++++---------
 llmodel/llamamodel.h   |  5 +++--
 llmodel/llmodel.h      |  5 +++--
 llmodel/llmodel_c.cpp  | 18 +++++++++++++-----
 llmodel/llmodel_c.h    | 19 ++++++++++++++-----
 9 files changed, 81 insertions(+), 47 deletions(-)

diff --git a/llm.cpp b/llm.cpp
index 1c489998..a218d40d 100644
--- a/llm.cpp
+++ b/llm.cpp
@@ -38,7 +38,7 @@ static QString modelFilePath(const QString &modelName)
 LLMObject::LLMObject()
     : QObject{nullptr}
     , m_llmodel(nullptr)
-    , m_responseTokens(0)
+    , m_promptResponseTokens(0)
     , m_responseLogits(0)
     , m_isRecalc(false)
 {
@@ -133,12 +133,12 @@ bool LLMObject::isModelLoaded() const
 
 void LLMObject::regenerateResponse()
 {
-    s_ctx.n_past -= m_responseTokens;
+    s_ctx.n_past -= m_promptResponseTokens;
     s_ctx.n_past = std::max(0, s_ctx.n_past);
     // FIXME: This does not seem to be needed in my testing and llama models don't to it. Remove?
     s_ctx.logits.erase(s_ctx.logits.end() -= m_responseLogits, s_ctx.logits.end());
-    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_responseTokens, s_ctx.tokens.end());
-    m_responseTokens = 0;
+    s_ctx.tokens.erase(s_ctx.tokens.end() -= m_promptResponseTokens, s_ctx.tokens.end());
+    m_promptResponseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
     emit responseChanged();
@@ -146,7 +146,7 @@ void LLMObject::regenerateResponse()
 
 void LLMObject::resetResponse()
 {
-    m_responseTokens = 0;
+    m_promptResponseTokens = 0;
     m_responseLogits = 0;
     m_response = std::string();
     emit responseChanged();
@@ -263,6 +263,18 @@ QList<QString> LLMObject::modelList() const
     return list;
 }
 
+bool LLMObject::handlePrompt(int32_t token)
+{
+    if (s_ctx.tokens.size() == s_ctx.n_ctx)
+        s_ctx.tokens.erase(s_ctx.tokens.begin());
+    s_ctx.tokens.push_back(token);
+
+    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
+    // the entire context window which we can reset on regenerate prompt
+    ++m_promptResponseTokens;
+    return !m_stopGenerating;
+}
+
 bool LLMObject::handleResponse(int32_t token, const std::string &response)
 {
 #if 0
@@ -282,13 +294,12 @@ bool LLMObject::handleResponse(int32_t token, const std::string &response)
         s_ctx.tokens.erase(s_ctx.tokens.begin());
     s_ctx.tokens.push_back(token);
 
-    // m_responseTokens and m_responseLogits are related to last prompt/response not
+    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
-    ++m_responseTokens;
-    if (!response.empty()) {
-        m_response.append(response);
-        emit responseChanged();
-    }
+    ++m_promptResponseTokens;
+    Q_ASSERT(!response.empty());
+    m_response.append(response);
+    emit responseChanged();
 
     // Stop generation if we encounter prompt or response tokens
     QString r = QString::fromStdString(m_response);
@@ -315,6 +326,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     QString instructPrompt = prompt_template.arg(prompt);
 
     m_stopGenerating = false;
+    auto promptFunc = std::bind(&LLMObject::handlePrompt, this, std::placeholders::_1);
     auto responseFunc = std::bind(&LLMObject::handleResponse, this, std::placeholders::_1,
         std::placeholders::_2);
     auto recalcFunc = std::bind(&LLMObject::handleRecalculate, this, std::placeholders::_1);
@@ -327,7 +339,7 @@ bool LLMObject::prompt(const QString &prompt, const QString &prompt_template, in
     s_ctx.n_batch = n_batch;
     s_ctx.repeat_penalty = repeat_penalty;
     s_ctx.repeat_last_n = repeat_penalty_tokens;
-    m_llmodel->prompt(instructPrompt.toStdString(), responseFunc, recalcFunc, s_ctx);
+    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, s_ctx);
     m_responseLogits += s_ctx.logits.size() - logitsBefore;
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
diff --git a/llm.h b/llm.h
index 0bf8025e..089a63a4 100644
--- a/llm.h
+++ b/llm.h
@@ -58,13 +58,14 @@ Q_SIGNALS:
 private:
     void resetContextPrivate();
     bool loadModelPrivate(const QString &modelName);
+    bool handlePrompt(int32_t token);
     bool handleResponse(int32_t token, const std::string &response);
     bool handleRecalculate(bool isRecalc);
 
 private:
     LLModel *m_llmodel;
     std::string m_response;
-    quint32 m_responseTokens;
+    quint32 m_promptResponseTokens;
     quint32 m_responseLogits;
     QString m_modelName;
     QThread m_llmThread;
diff --git a/llmodel/gptj.cpp b/llmodel/gptj.cpp
index eef6d03a..36eeaf27 100644
--- a/llmodel/gptj.cpp
+++ b/llmodel/gptj.cpp
@@ -686,8 +686,9 @@ bool GPTJ::isModelLoaded() const
 }
 
 void GPTJ::prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -708,7 +709,7 @@ void GPTJ::prompt(const std::string &prompt,
     promptCtx.n_ctx = d_ptr->model.hparams.n_ctx;
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
-        response(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
+        responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << "GPT-J ERROR: The prompt is" << embd_inp.size() <<
             "tokens and the context window is" << promptCtx.n_ctx << "!\n";
         return;
@@ -741,7 +742,7 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPTJ: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
+            recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
@@ -750,10 +751,10 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
-        // We pass a null string for each token to see if the user has asked us to stop...
+
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(batch.at(t), ""))
+            if (!promptCallback(batch.at(t)))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -790,8 +791,8 @@ void GPTJ::prompt(const std::string &prompt,
             std::cerr << "GPTJ: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
-            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+            recalculateContext(promptCtx, recalculateCallback);
+            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
         const int64_t t_start_predict_us = ggml_time_us();
@@ -805,7 +806,7 @@ void GPTJ::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == 50256 /*end of text*/ || !response(id, d_ptr->vocab.id_to_token[id]))
+        if (id == 50256 /*end of text*/ || !responseCallback(id, d_ptr->vocab.id_to_token[id]))
             goto stop_generating;
     }
 
diff --git a/llmodel/gptj.h b/llmodel/gptj.h
index 17cb069c..6f19dcd1 100644
--- a/llmodel/gptj.h
+++ b/llmodel/gptj.h
@@ -16,8 +16,9 @@ public:
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
     void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
diff --git a/llmodel/llamamodel.cpp b/llmodel/llamamodel.cpp
index 61318592..89c230fc 100644
--- a/llmodel/llamamodel.cpp
+++ b/llmodel/llamamodel.cpp
@@ -80,8 +80,9 @@ bool LLamaModel::isModelLoaded() const
 }
 
 void LLamaModel::prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &promptCtx) {
 
     if (!isModelLoaded()) {
@@ -102,7 +103,7 @@ void LLamaModel::prompt(const std::string &prompt,
     promptCtx.n_ctx = llama_n_ctx(d_ptr->ctx);
 
     if ((int) embd_inp.size() > promptCtx.n_ctx - 4) {
-        response(-1, "The prompt size exceeds the context window size and cannot be processed.");
+        responseCallback(-1, "The prompt size exceeds the context window size and cannot be processed.");
         std::cerr << "LLAMA ERROR: The prompt is" << embd_inp.size() <<
             "tokens and the context window is" << promptCtx.n_ctx << "!\n";
         return;
@@ -128,7 +129,7 @@ void LLamaModel::prompt(const std::string &prompt,
             std::cerr << "LLAMA: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
+            recalculateContext(promptCtx, recalculateCallback);
             assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
         }
 
@@ -137,10 +138,9 @@ void LLamaModel::prompt(const std::string &prompt,
             return;
         }
 
-        // We pass a null string for each token to see if the user has asked us to stop...
         size_t tokens = batch_end - i;
         for (size_t t = 0; t < tokens; ++t)
-            if (!response(batch.at(t), ""))
+            if (!promptCallback(batch.at(t)))
                 return;
         promptCtx.n_past += batch.size();
         i = batch_end;
@@ -162,8 +162,8 @@ void LLamaModel::prompt(const std::string &prompt,
             std::cerr << "LLAMA: reached the end of the context window so resizing\n";
             promptCtx.tokens.erase(promptCtx.tokens.begin(), promptCtx.tokens.begin() + erasePoint);
             promptCtx.n_past = promptCtx.tokens.size();
-            recalculateContext(promptCtx, recalculate);
-            assert(promptCtx.n_past + batch.size() <= promptCtx.n_ctx);
+            recalculateContext(promptCtx, recalculateCallback);
+            assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
         if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
@@ -174,7 +174,7 @@ void LLamaModel::prompt(const std::string &prompt,
         promptCtx.n_past += 1;
         // display text
         ++totalPredictions;
-        if (id == llama_token_eos() || !response(id, llama_token_to_str(d_ptr->ctx, id)))
+        if (id == llama_token_eos() || !responseCallback(id, llama_token_to_str(d_ptr->ctx, id)))
             return;
     }
 }
diff --git a/llmodel/llamamodel.h b/llmodel/llamamodel.h
index 163260bb..c97f80b7 100644
--- a/llmodel/llamamodel.h
+++ b/llmodel/llamamodel.h
@@ -16,8 +16,9 @@ public:
     bool loadModel(const std::string &modelPath, std::istream &fin) override;
     bool isModelLoaded() const override;
    void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() override;
diff --git a/llmodel/llmodel.h b/llmodel/llmodel.h
index 5945eb98..0cc53689 100644
--- a/llmodel/llmodel.h
+++ b/llmodel/llmodel.h
@@ -29,8 +29,9 @@ public:
                                  // window
     };
     virtual void prompt(const std::string &prompt,
-        std::function<bool(int32_t, const std::string&)> response,
-        std::function<bool(bool)> recalculate,
+        std::function<bool(int32_t)> promptCallback,
+        std::function<bool(int32_t, const std::string&)> responseCallback,
+        std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) = 0;
     virtual void setThreadCount(int32_t n_threads) {}
     virtual int32_t threadCount() { return 1; }
diff --git a/llmodel/llmodel_c.cpp b/llmodel/llmodel_c.cpp
index ec430dcb..46eb1a7d 100644
--- a/llmodel/llmodel_c.cpp
+++ b/llmodel/llmodel_c.cpp
@@ -49,6 +49,11 @@ bool llmodel_isModelLoaded(llmodel_model model)
 }
 
 // Wrapper functions for the C callbacks
+bool prompt_wrapper(int32_t token_id, void *user_data) {
+    llmodel_prompt_callback callback = reinterpret_cast<llmodel_prompt_callback>(user_data);
+    return callback(token_id);
+}
+
 bool response_wrapper(int32_t token_id, const std::string &response, void *user_data) {
     llmodel_response_callback callback = reinterpret_cast<llmodel_response_callback>(user_data);
     return callback(token_id, response.c_str());
@@ -60,17 +65,20 @@ bool recalculate_wrapper(bool is_recalculating, void *user_data) {
     llmodel_recalculate_callback callback = reinterpret_cast<llmodel_recalculate_callback>(user_data);
     return callback(is_recalculating);
 }
 
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback response,
-                    llmodel_recalculate_callback recalculate,
+                    llmodel_prompt_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx) {
     LLModelWrapper *wrapper = reinterpret_cast<LLModelWrapper*>(model);
     // Create std::function wrappers that call the C function pointers
+    std::function<bool(int32_t)> prompt_func =
+        std::bind(&prompt_wrapper, std::placeholders::_1, reinterpret_cast<void*>(prompt_callback));
     std::function<bool(int32_t, const std::string&)> response_func =
-        std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response));
+        std::bind(&response_wrapper, std::placeholders::_1, std::placeholders::_2, reinterpret_cast<void*>(response_callback));
     std::function<bool(bool)> recalc_func =
-        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate));
+        std::bind(&recalculate_wrapper, std::placeholders::_1, reinterpret_cast<void*>(recalculate_callback));
 
     // Copy the C prompt context
     wrapper->promptContext.n_past = ctx->n_past;
@@ -85,7 +93,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt,
     wrapper->promptContext.contextErase = ctx->context_erase;
 
     // Call the C++ prompt method
-    wrapper->llModel->prompt(prompt, response_func, recalc_func, wrapper->promptContext);
+    wrapper->llModel->prompt(prompt, prompt_func, response_func, recalc_func, wrapper->promptContext);
 
     // Update the C context by giving access to the wrappers raw pointers to std::vector data
     // which involves no copies
diff --git a/llmodel/llmodel_c.h b/llmodel/llmodel_c.h
index b0b3fa95..45cc9cd2 100644
--- a/llmodel/llmodel_c.h
+++ b/llmodel/llmodel_c.h
@@ -37,10 +37,17 @@ typedef struct {
     float context_erase;   // percent of context to erase if we exceed the context window
 } llmodel_prompt_context;
 
+/**
+ * Callback type for prompt processing.
+ * @param token_id The token id of the prompt.
+ * @return a bool indicating whether the model should keep processing.
+ */
+typedef bool (*llmodel_prompt_callback)(int32_t token_id);
+
 /**
  * Callback type for response.
  * @param token_id The token id of the response.
- * @param response The response string.
+ * @param response The response string. NOTE: a token_id of -1 indicates the string is an error string.
  * @return a bool indicating whether the model should keep generating.
  */
 typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response);
@@ -95,13 +102,15 @@ bool llmodel_isModelLoaded(llmodel_model model);
  * Generate a response using the model.
  * @param model A pointer to the llmodel_model instance.
  * @param prompt A string representing the input prompt.
- * @param response A callback function for handling the generated response.
- * @param recalculate A callback function for handling recalculation requests.
+ * @param prompt_callback A callback function for handling the processing of the prompt.
+ * @param response_callback A callback function for handling the generated response.
+ * @param recalculate_callback A callback function for handling recalculation requests.
  * @param ctx A pointer to the llmodel_prompt_context structure.
  */
 void llmodel_prompt(llmodel_model model, const char *prompt,
-                    llmodel_response_callback response,
-                    llmodel_recalculate_callback recalculate,
+                    llmodel_prompt_callback prompt_callback,
+                    llmodel_response_callback response_callback,
+                    llmodel_recalculate_callback recalculate_callback,
                     llmodel_prompt_context *ctx);
 
 /**
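
Note (not part of the patch): splitting prompt ingestion into its own callback means the batch-processing loop no longer has to invoke the response callback with an empty string just to give the caller a chance to cancel. Below is a minimal sketch of how a C client might wire up the three callbacks after this change. llmodel_prompt, llmodel_prompt_context, and the callback typedefs come from llmodel_c.h as changed above; the helper names (on_prompt, on_response, on_recalculate, run_prompt), the literal prompt string, and the zero-initialized context are illustrative assumptions. A real caller would create and load the model through the rest of the llmodel_c API and fill in the context's sampling parameters.

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include "llmodel_c.h"

/* Prompt-processing callback: called once per prompt token while the prompt
 * is ingested; returning false asks the model to stop processing. */
static bool on_prompt(int32_t token_id) {
    (void)token_id;
    return true;
}

/* Response callback: called for each generated token; per the header docs,
 * a token_id of -1 means `response` is an error string. */
static bool on_response(int32_t token_id, const char *response) {
    if (token_id == -1) {
        fprintf(stderr, "llmodel error: %s\n", response);
        return false;
    }
    fputs(response, stdout);
    return true;
}

/* Recalculation callback: returning false aborts the context recalculation. */
static bool on_recalculate(bool is_recalculating) {
    (void)is_recalculating;
    return true;
}

/* Illustrative driver: `model` is assumed to have been created and loaded
 * elsewhere via the llmodel_c API. */
static void run_prompt(llmodel_model model) {
    llmodel_prompt_context ctx = {0};   /* a real caller would set the sampling parameters */
    llmodel_prompt(model, "Hello, world!",
                   on_prompt, on_response, on_recalculate, &ctx);
}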