From 7ee32d605f876c9757555d3218f8ea4b7eb3ee93 Mon Sep 17 00:00:00 2001
From: Adam Treat
Date: Thu, 1 Jun 2023 23:15:58 -0400
Subject: [PATCH] Trying to shrink the copy+paste code and do more code sharing between backend model impl.

---
 gpt4all-backend/gptj.cpp          | 31 ++++---------------------------
 gpt4all-backend/gptj_impl.h       |  5 +----
 gpt4all-backend/llamamodel.cpp    | 28 ++++------------------------
 gpt4all-backend/llamamodel_impl.h |  5 +----
 gpt4all-backend/llmodel.cpp       | 23 +++++++++++++++++++++++
 gpt4all-backend/llmodel.h         |  4 ++--
 gpt4all-backend/mpt.cpp           | 31 ++++---------------------------
 gpt4all-backend/mpt_impl.h        |  5 +----
 gpt4all-chat/chatgpt.h            |  5 +----
 9 files changed, 41 insertions(+), 96 deletions(-)

diff --git a/gpt4all-backend/gptj.cpp b/gpt4all-backend/gptj.cpp
index 5f0fbbe5..dff8878f 100644
--- a/gpt4all-backend/gptj.cpp
+++ b/gpt4all-backend/gptj.cpp
@@ -944,8 +944,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-                       d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -995,8 +994,7 @@ void GPTJ::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
-                       d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -1042,30 +1040,9 @@ void GPTJ::prompt(const std::string &prompt,
     }
 }
 
-void GPTJ::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool GPTJ::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (!gptj_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
-            std::cerr << "GPTJ ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return gptj_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
 }
 
 #if defined(_WIN32)
diff --git a/gpt4all-backend/gptj_impl.h b/gpt4all-backend/gptj_impl.h
index c5c99d21..4b209bd2 100644
--- a/gpt4all-backend/gptj_impl.h
+++ b/gpt4all-backend/gptj_impl.h
@@ -25,13 +25,10 @@ public:
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     GPTJPrivate *d_ptr;
 };
diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp
index c819aba4..86a7cc63 100644
--- a/gpt4all-backend/llamamodel.cpp
+++ b/gpt4all-backend/llamamodel.cpp
@@ -216,7 +216,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "LLAMA ERROR: Failed to process prompt\n";
             return;
         }
@@ -258,7 +258,7 @@ void LLamaModel::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (llama_eval(d_ptr->ctx, &id, 1, promptCtx.n_past, d_ptr->n_threads)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "LLAMA ERROR: Failed to predict next token\n";
             return;
         }
@@ -305,29 +305,9 @@ void LLamaModel::prompt(const std::string &prompt,
     }
 }
 
-void LLamaModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<llama_token> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (llama_eval(d_ptr->ctx, batch.data(), batch.size(), promptCtx.n_past, d_ptr->n_threads)) {
-            std::cerr << "LLAMA ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return llama_eval(d_ptr->ctx, tokens.data(), tokens.size(), ctx.n_past, d_ptr->n_threads) == 0;
 }
 
 #if defined(_WIN32)
diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h
index a8f18936..3c27fff8 100644
--- a/gpt4all-backend/llamamodel_impl.h
+++ b/gpt4all-backend/llamamodel_impl.h
@@ -25,13 +25,10 @@ public:
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     LLamaPrivate *d_ptr;
 };
diff --git a/gpt4all-backend/llmodel.cpp b/gpt4all-backend/llmodel.cpp
index b78370ae..0de32300 100644
--- a/gpt4all-backend/llmodel.cpp
+++ b/gpt4all-backend/llmodel.cpp
@@ -1,6 +1,7 @@
 #include "llmodel.h"
 #include "dlhandle.h"
 
+#include <cassert>
 #include
 #include
 #include
@@ -95,6 +96,28 @@ const LLModel::Implementation* LLModel::implementation(std::ifstream& f, const s
     return nullptr;
 }
 
+void LLModel::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
+    size_t i = 0;
+    promptCtx.n_past = 0;
+    while (i < promptCtx.tokens.size()) {
+        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
+        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
+        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
+        if (!evalTokens(promptCtx, batch)) {
+            std::cerr << "LLModel ERROR: Failed to process prompt\n";
+            goto stop_generating;
+        }
+        promptCtx.n_past += batch.size();
+        if (!recalculate(true))
+            goto stop_generating;
+        i = batch_end;
+    }
+    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
+
+stop_generating:
+    recalculate(false);
+}
+
 LLModel *LLModel::construct(const std::string &modelPath, std::string buildVariant) {
     //TODO: Auto-detect CUDA/OpenCL
     if (buildVariant == "auto") {
diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h
index f57a1e8a..45a3a3c2 100644
--- a/gpt4all-backend/llmodel.h
+++ b/gpt4all-backend/llmodel.h
@@ -64,6 +64,7 @@ public:
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) = 0;
+    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;
     virtual void setThreadCount(int32_t /*n_threads*/) {}
     virtual int32_t threadCount() const { return 1; }
 
@@ -78,7 +79,6 @@ public:
 protected:
     const Implementation *m_implementation = nullptr;
 
-    virtual void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) = 0;
+    void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate);
 };
 #endif // LLMODEL_H
diff --git a/gpt4all-backend/mpt.cpp b/gpt4all-backend/mpt.cpp
index b1bbf6e9..b37dc2b6 100644
--- a/gpt4all-backend/mpt.cpp
+++ b/gpt4all-backend/mpt.cpp
@@ -869,8 +869,7 @@ void MPT::prompt(const std::string &prompt,
             assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
         }
 
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-                      d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, batch)) {
             std::cerr << "GPT-J ERROR: Failed to process prompt\n";
             return;
         }
@@ -920,8 +919,7 @@ void MPT::prompt(const std::string &prompt,
             assert(promptCtx.n_past + 1 <= promptCtx.n_ctx);
         }
 
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, { id }, promptCtx.logits,
-                      d_ptr->mem_per_token)) {
+        if (!evalTokens(promptCtx, { id })) {
             std::cerr << "GPT-J ERROR: Failed to predict next token\n";
             return;
         }
@@ -971,30 +969,9 @@ void MPT::prompt(const std::string &prompt,
     }
 }
 
-void MPT::recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate)
+bool MPT::evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens)
 {
-    size_t i = 0;
-    promptCtx.n_past = 0;
-    while (i < promptCtx.tokens.size()) {
-        size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size());
-        std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
-
-        assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
-
-        if (!mpt_eval(*d_ptr->model, d_ptr->n_threads, promptCtx.n_past, batch, promptCtx.logits,
-            d_ptr->mem_per_token)) {
-            std::cerr << "MPT ERROR: Failed to process prompt\n";
-            goto stop_generating;
-        }
-        promptCtx.n_past += batch.size();
-        if (!recalculate(true))
-            goto stop_generating;
-        i = batch_end;
-    }
-    assert(promptCtx.n_past == int32_t(promptCtx.tokens.size()));
-
-stop_generating:
-    recalculate(false);
+    return mpt_eval(*d_ptr->model, d_ptr->n_threads, ctx.n_past, tokens, ctx.logits, d_ptr->mem_per_token);
 }
 
 #if defined(_WIN32)
diff --git a/gpt4all-backend/mpt_impl.h b/gpt4all-backend/mpt_impl.h
index 31095afd..f645b8bf 100644
--- a/gpt4all-backend/mpt_impl.h
+++ b/gpt4all-backend/mpt_impl.h
@@ -25,13 +25,10 @@ public:
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override;
-
 private:
     MPTPrivate *d_ptr;
 };
diff --git a/gpt4all-chat/chatgpt.h b/gpt4all-chat/chatgpt.h
index 62913ded..eb20a722 100644
--- a/gpt4all-chat/chatgpt.h
+++ b/gpt4all-chat/chatgpt.h
@@ -24,6 +24,7 @@ public:
         std::function<bool(int32_t, const std::string&)> responseCallback,
         std::function<bool(bool)> recalculateCallback,
         PromptContext &ctx) override;
+    bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) override { return true; }
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
 
@@ -33,10 +34,6 @@ public:
     QList<QString> context() const { return m_context; }
     void setContext(const QList<QString> &context) { m_context = context; }
 
-protected:
-    void recalculateContext(PromptContext &promptCtx,
-        std::function<bool(bool)> recalculate) override {}
-
 private Q_SLOTS:
     void handleFinished();
     void handleReadyRead();
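
Below, as a reading aid rather than part of the patch, is a minimal self-contained sketch of the sharing pattern this commit introduces: recalculateContext() lives once in the LLModel base class and defers the backend-specific work to the new evalTokens() virtual. The simplified PromptContext fields and the FakeModel backend are assumptions made for illustration; only the recalculateContext/evalTokens split mirrors the patch.

// Sketch only: shared recalculateContext() in the base class, per-backend
// evalTokens() virtual. PromptContext and FakeModel are simplified stand-ins,
// not the real gpt4all types.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

struct PromptContext {
    std::vector<int32_t> tokens;   // tokens currently in the context window
    int32_t n_past  = 0;           // number of tokens already evaluated
    int32_t n_ctx   = 2048;        // context window size
    int32_t n_batch = 8;           // tokens evaluated per evalTokens() call
};

class LLModel {
public:
    virtual ~LLModel() = default;

    // Each backend implements only this: evaluate one batch of tokens.
    virtual bool evalTokens(PromptContext &ctx, const std::vector<int32_t> &tokens) = 0;

    // Shared logic that was previously copy+pasted into GPTJ, LLamaModel and MPT:
    // replay the stored tokens in n_batch-sized chunks through evalTokens().
    void recalculateContext(PromptContext &promptCtx, std::function<bool(bool)> recalculate) {
        size_t i = 0;
        promptCtx.n_past = 0;
        while (i < promptCtx.tokens.size()) {
            size_t batch_end = std::min(i + size_t(promptCtx.n_batch), promptCtx.tokens.size());
            std::vector<int32_t> batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end);
            assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx);
            if (!evalTokens(promptCtx, batch)) {
                std::cerr << "LLModel ERROR: Failed to process prompt\n";
                break;
            }
            promptCtx.n_past += batch.size();
            if (!recalculate(true))   // caller can abort the recalculation
                break;
            i = batch_end;
        }
        recalculate(false);           // signal that recalculation has finished
    }
};

// Illustrative backend: "evaluates" a batch by simply counting its tokens.
class FakeModel : public LLModel {
public:
    bool evalTokens(PromptContext &, const std::vector<int32_t> &tokens) override {
        evaluated += tokens.size();
        return true;
    }
    size_t evaluated = 0;
};

int main() {
    FakeModel model;
    PromptContext ctx;
    ctx.tokens = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    model.recalculateContext(ctx, [](bool) { return true; });
    std::cout << "re-evaluated " << model.evaluated
              << " tokens, n_past = " << ctx.n_past << "\n";
    return 0;
}

With this split in place, each real backend only supplies its own evalTokens() body wrapping gptj_eval, llama_eval, or mpt_eval, which is what allows the three duplicated recalculateContext() copies to be deleted in this patch.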