diff --git a/gpt4all-backend/CMakeLists.txt b/gpt4all-backend/CMakeLists.txt index b47c4505..5862c726 100644 --- a/gpt4all-backend/CMakeLists.txt +++ b/gpt4all-backend/CMakeLists.txt @@ -33,7 +33,7 @@ set(LLMODEL_VERSION_PATCH 0) set(LLMODEL_VERSION "${LLMODEL_VERSION_MAJOR}.${LLMODEL_VERSION_MINOR}.${LLMODEL_VERSION_PATCH}") project(llmodel VERSION ${LLMODEL_VERSION} LANGUAGES CXX C) -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}) set(BUILD_SHARED_LIBS ON) diff --git a/gpt4all-backend/llamamodel.cpp b/gpt4all-backend/llamamodel.cpp index cab0e75a..f07a05e8 100644 --- a/gpt4all-backend/llamamodel.cpp +++ b/gpt4all-backend/llamamodel.cpp @@ -531,10 +531,7 @@ size_t LLamaModel::restoreState(const uint8_t *src) std::vector LLamaModel::tokenize(PromptContext &ctx, const std::string &str, bool special) { bool atStart = m_tokenize_last_token == -1; - bool insertSpace = atStart || ( - llama_token_get_attr(d_ptr->model, m_tokenize_last_token) - & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN) - ); + bool insertSpace = atStart || isSpecialToken(m_tokenize_last_token); std::vector fres(str.length() + 4); int32_t fres_len = llama_tokenize_gpt4all( d_ptr->model, str.c_str(), str.length(), fres.data(), fres.size(), /*add_special*/ atStart, @@ -546,6 +543,12 @@ std::vector LLamaModel::tokenize(PromptContext &ctx, const std:: return fres; } +bool LLamaModel::isSpecialToken(Token id) const +{ + return llama_token_get_attr(d_ptr->model, id) + & (LLAMA_TOKEN_ATTR_CONTROL | LLAMA_TOKEN_ATTR_USER_DEFINED | LLAMA_TOKEN_ATTR_UNKNOWN); +} + std::string LLamaModel::tokenToString(Token id) const { std::vector result(8, 0); @@ -595,6 +598,30 @@ bool LLamaModel::evalTokens(PromptContext &ctx, const std::vector &toke return res == 0; } +void LLamaModel::shiftContext(PromptContext &promptCtx) +{ + // infinite text generation via context shifting + + // erase up to n_ctx*contextErase tokens + int n_keep = shouldAddBOS(); + int n_past = promptCtx.n_past; + int n_discard = std::min(n_past - n_keep, int(promptCtx.n_ctx * promptCtx.contextErase)); + + assert(n_discard > 0); + if (n_discard <= 0) + return; + + std::cerr << "Llama: context full, swapping: n_past = " << n_past << ", n_keep = " << n_keep + << ", n_discard = " << n_discard << "\n"; + + // erase the first n_discard tokens from the context + llama_kv_cache_seq_rm (d_ptr->ctx, 0, n_keep, n_keep + n_discard); + llama_kv_cache_seq_add(d_ptr->ctx, 0, n_keep + n_discard, n_past, -n_discard); + + promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard); + promptCtx.n_past = promptCtx.tokens.size(); +} + int32_t LLamaModel::contextLength() const { return llama_n_ctx(d_ptr->ctx); diff --git a/gpt4all-backend/llamamodel_impl.h b/gpt4all-backend/llamamodel_impl.h index 019e5532..7c698ffa 100644 --- a/gpt4all-backend/llamamodel_impl.h +++ b/gpt4all-backend/llamamodel_impl.h @@ -6,7 +6,6 @@ #include "llmodel.h" -#include #include #include #include @@ -54,9 +53,11 @@ private: protected: std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override; + bool isSpecialToken(Token id) const override; std::string tokenToString(Token id) const override; Token sampleToken(PromptContext &ctx) const override; bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override; + void shiftContext(PromptContext &promptCtx) override; int32_t 
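The new LLamaModel::shiftContext above delegates the KV-cache surgery to llama.cpp (llama_kv_cache_seq_rm / llama_kv_cache_seq_add), but the token-cache bookkeeping it performs can be illustrated with a self-contained sketch. Everything below (ShiftingWindow, feedToken, the token values) is a hypothetical stand-in for PromptContext, not code from this change:

    #include <algorithm>
    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for PromptContext: only the fields the shift touches.
    struct ShiftingWindow {
        std::vector<int> tokens;   // token cache
        int n_past = 0;            // tokens already evaluated
        int n_ctx = 8;             // context window size
        float contextErase = 0.5f; // fraction of the window to discard when full (new default)
    };

    // Mirrors the arithmetic of LLamaModel::shiftContext: keep the first n_keep tokens
    // (the BOS token when shouldAddBOS() is true), drop the next n_discard, shrink n_past.
    void shiftContext(ShiftingWindow &w, int n_keep)
    {
        int n_discard = std::min(w.n_past - n_keep, int(w.n_ctx * w.contextErase));
        if (n_discard <= 0)
            return;
        w.tokens.erase(w.tokens.begin() + n_keep, w.tokens.begin() + n_keep + n_discard);
        w.n_past = int(w.tokens.size());
    }

    // Append one token, shifting first if the window is already full.
    void feedToken(ShiftingWindow &w, int tok)
    {
        if (w.n_past >= w.n_ctx)
            shiftContext(w, /*n_keep*/ 1);
        w.tokens.push_back(tok);
        w.n_past += 1;
    }

    int main()
    {
        ShiftingWindow w;
        w.tokens = {1}; // BOS
        w.n_past = 1;
        for (int tok = 100; tok < 120; tok++)
            feedToken(w, tok);
        // The window never exceeds n_ctx and always keeps the BOS token at the front.
        assert(int(w.tokens.size()) <= w.n_ctx && w.tokens.front() == 1);
        std::printf("kept %zu of 21 tokens, n_past=%d\n", w.tokens.size(), w.n_past);
    }

With contextErase = 0.5 each shift frees half of the window at once, and unlike the removed recalculate path the surviving tokens are not re-evaluated; the KV cache is simply slid back in place.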
contextLength() const override; const std::vector &endTokens() const override; bool shouldAddBOS() const override; diff --git a/gpt4all-backend/llmodel.h b/gpt4all-backend/llmodel.h index f95dc3a8..04a510dc 100644 --- a/gpt4all-backend/llmodel.h +++ b/gpt4all-backend/llmodel.h @@ -134,7 +134,7 @@ public: int32_t n_batch = 9; float repeat_penalty = 1.10f; int32_t repeat_last_n = 64; // last n tokens to penalize - float contextErase = 0.75f; // percent of context to erase if we exceed the context window + float contextErase = 0.5f; // percent of context to erase if we exceed the context window }; using ProgressCallback = std::function; @@ -159,7 +159,7 @@ public: const std::string &promptTemplate, std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &ctx, bool special = false, std::string *fakeReply = nullptr); @@ -213,9 +213,11 @@ protected: // These are pure virtual because subclasses need to implement as the default implementation of // 'prompt' above calls these functions virtual std::vector tokenize(PromptContext &ctx, const std::string &str, bool special = false) = 0; + virtual bool isSpecialToken(Token id) const = 0; virtual std::string tokenToString(Token id) const = 0; virtual Token sampleToken(PromptContext &ctx) const = 0; virtual bool evalTokens(PromptContext &ctx, const std::vector &tokens) const = 0; + virtual void shiftContext(PromptContext &promptCtx) = 0; virtual int32_t contextLength() const = 0; virtual const std::vector &endTokens() const = 0; virtual bool shouldAddBOS() const = 0; @@ -232,10 +234,6 @@ protected: return -1; } - // This is a helper function called from the default implementation of 'prompt' but it can be - // shared by all base classes so it isn't virtual - void recalculateContext(PromptContext &promptCtx, std::function recalculate); - const Implementation *m_implementation = nullptr; ProgressCallback m_progressCallback; @@ -249,11 +247,11 @@ protected: bool decodePrompt(std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx, std::vector embd_inp); void generateResponse(std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx); Token m_tokenize_last_token = -1; // not serialized diff --git a/gpt4all-backend/llmodel_c.cpp b/gpt4all-backend/llmodel_c.cpp index d6ba95e8..f3fd68ff 100644 --- a/gpt4all-backend/llmodel_c.cpp +++ b/gpt4all-backend/llmodel_c.cpp @@ -106,7 +106,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt, const char *prompt_template, llmodel_prompt_callback prompt_callback, llmodel_response_callback response_callback, - llmodel_recalculate_callback recalculate_callback, + bool allow_context_shift, llmodel_prompt_context *ctx, bool special, const char *fake_reply) @@ -135,7 +135,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt, auto *fake_reply_p = fake_reply ? 
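This is the caller-facing half of the change: LLModel::prompt no longer takes a recalculate callback, only a bool saying whether the backend may shift context when the window fills. A minimal hypothetical caller, assuming the gpt4all-backend header is on the include path, an already-loaded model, and the callback signatures used elsewhere in this diff (bool(int32_t) for prompt tokens, bool(int32_t, const std::string &) for response pieces); askOnce and the bare "%1" template are illustrative only:

    #include "llmodel.h"

    #include <cstdint>
    #include <iostream>
    #include <string>

    void askOnce(LLModel &model, const std::string &question)
    {
        LLModel::PromptContext ctx; // contextErase now defaults to 0.5

        auto onPromptToken = [](int32_t /*token*/) { return true; }; // keep processing the prompt
        auto onResponseToken = [](int32_t /*token*/, const std::string &piece) {
            std::cout << piece << std::flush;
            return true; // keep generating
        };

        // Previously the argument after the response callback was a recalculate callback;
        // now the caller simply states whether old context may be discarded to make room.
        model.prompt(question, "%1", onPromptToken, onResponseToken,
                     /*allowContextShift*/ true, ctx);
    }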
&fake_reply_str : nullptr; // Call the C++ prompt method - wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, recalculate_callback, + wrapper->llModel->prompt(prompt, prompt_template, prompt_callback, response_func, allow_context_shift, wrapper->promptContext, special, fake_reply_p); // Update the C context by giving access to the wrappers raw pointers to std::vector data diff --git a/gpt4all-backend/llmodel_c.h b/gpt4all-backend/llmodel_c.h index 75cfb862..327bea2e 100644 --- a/gpt4all-backend/llmodel_c.h +++ b/gpt4all-backend/llmodel_c.h @@ -74,13 +74,6 @@ typedef bool (*llmodel_prompt_callback)(int32_t token_id); */ typedef bool (*llmodel_response_callback)(int32_t token_id, const char *response); -/** - * Callback type for recalculation of context. - * @param whether the model is recalculating the context. - * @return a bool indicating whether the model should keep generating. - */ -typedef bool (*llmodel_recalculate_callback)(bool is_recalculating); - /** * Embedding cancellation callback for use with llmodel_embed. * @param batch_sizes The number of tokens in each batch that will be embedded. @@ -175,7 +168,7 @@ uint64_t llmodel_restore_state_data(llmodel_model model, const uint8_t *src); * @param prompt_template A string representing the input prompt template. * @param prompt_callback A callback function for handling the processing of prompt. * @param response_callback A callback function for handling the generated response. - * @param recalculate_callback A callback function for handling recalculation requests. + * @param allow_context_shift Whether to allow shifting of context to make room for more input. * @param special True if special tokens in the prompt should be processed, false otherwise. * @param fake_reply A string to insert into context as the model's reply, or NULL to generate one. * @param ctx A pointer to the llmodel_prompt_context structure. 
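The C API follows suit in the declaration hunk just below: the llmodel_recalculate_callback parameter becomes a plain bool. A hedged usage sketch, assuming a handle that has already been created and loaded and an initialized llmodel_prompt_context (ask_once and the "%1" template are illustrative):

    #include "llmodel_c.h"

    #include <cstdint>
    #include <cstdio>

    static bool on_prompt_token(int32_t token_id)
    {
        (void)token_id;
        return true; // keep feeding the prompt
    }

    static bool on_response_token(int32_t token_id, const char *piece)
    {
        (void)token_id;
        std::fputs(piece, stdout);
        return true; // keep generating
    }

    static void ask_once(llmodel_model model, llmodel_prompt_context *ctx, const char *question)
    {
        // The recalculate callback argument is gone; pass allow_context_shift instead.
        llmodel_prompt(model, question, "%1",
                       on_prompt_token, on_response_token,
                       /*allow_context_shift*/ true,
                       ctx, /*special*/ false, /*fake_reply*/ nullptr);
    }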
@@ -184,7 +177,7 @@ void llmodel_prompt(llmodel_model model, const char *prompt, const char *prompt_template, llmodel_prompt_callback prompt_callback, llmodel_response_callback response_callback, - llmodel_recalculate_callback recalculate_callback, + bool allow_context_shift, llmodel_prompt_context *ctx, bool special, const char *fake_reply); diff --git a/gpt4all-backend/llmodel_shared.cpp b/gpt4all-backend/llmodel_shared.cpp index 68ea42e4..7477254a 100644 --- a/gpt4all-backend/llmodel_shared.cpp +++ b/gpt4all-backend/llmodel_shared.cpp @@ -11,42 +11,9 @@ #include #include #include -#include #include -// TODO(cebtenzzre): replace this with llama_kv_cache_seq_shift for llamamodel (GPT-J needs this as-is) -// FIXME(jared): if recalculate returns false, we leave n_past recalculate) -{ - int n_keep = shouldAddBOS(); - const int32_t n_discard = (promptCtx.n_ctx - n_keep) * promptCtx.contextErase; - - // Erase the first percentage of context from the tokens - std::cerr << implementation().modelType() << ": reached the end of the context window so resizing\n"; - promptCtx.tokens.erase(promptCtx.tokens.begin() + n_keep, promptCtx.tokens.begin() + n_keep + n_discard); - - size_t i = n_keep; - promptCtx.n_past = n_keep; - while (i < promptCtx.tokens.size()) { - size_t batch_end = std::min(i + promptCtx.n_batch, promptCtx.tokens.size()); - std::vector batch(promptCtx.tokens.begin() + i, promptCtx.tokens.begin() + batch_end); - assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); - if (!evalTokens(promptCtx, batch)) { - std::cerr << "LLModel ERROR: Failed to process prompt\n"; - goto stop_generating; - } - promptCtx.n_past += batch.size(); - if (!recalculate(true)) - goto stop_generating; - i = batch_end; - } - assert(promptCtx.n_past == int32_t(promptCtx.tokens.size())); - -stop_generating: - recalculate(false); -} +namespace ranges = std::ranges; static bool parsePromptTemplate(const std::string &tmpl, std::vector &placeholders, std::string &err) { @@ -75,7 +42,7 @@ void LLModel::prompt(const std::string &prompt, const std::string &promptTemplate, std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx, bool special, std::string *fakeReply) @@ -92,12 +59,21 @@ void LLModel::prompt(const std::string &prompt, return; } - // make sure token cache matches decode offset - if (promptCtx.tokens.size() < promptCtx.n_past) { + // sanity checks + if (promptCtx.n_past > contextLength()) { std::ostringstream ss; - ss << "expected n_past to be at most " << promptCtx.tokens.size() << ", got " << promptCtx.n_past; + ss << "n_past=" << promptCtx.n_past << " is past end of context length=" << contextLength(); throw std::out_of_range(ss.str()); } + if (promptCtx.n_past > promptCtx.tokens.size()) { + std::ostringstream ss; + ss << "n_past=" << promptCtx.n_past << " is past end of token cache length=" << promptCtx.tokens.size(); + throw std::out_of_range(ss.str()); + } + + promptCtx.n_ctx = contextLength(); + promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH); + if (promptCtx.n_past < promptCtx.tokens.size()) promptCtx.tokens.resize(promptCtx.n_past); m_tokenize_last_token = promptCtx.tokens.empty() ? 
-1 : promptCtx.tokens.back(); // not serialized @@ -149,15 +125,15 @@ void LLModel::prompt(const std::string &prompt, promptCtx.n_past = old_n_past; // restore n_past so decodePrompt can increment it // decode the user prompt - if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp)) + if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp)) return; // error // decode the assistant's reply, either generated or spoofed if (fakeReply == nullptr) { - generateResponse(responseCallback, recalculateCallback, promptCtx); + generateResponse(responseCallback, allowContextShift, promptCtx); } else { embd_inp = tokenize(promptCtx, *fakeReply, false); - if (!decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp)) + if (!decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp)) return; // error } @@ -172,19 +148,16 @@ void LLModel::prompt(const std::string &prompt, } if (!asstSuffix.empty()) { embd_inp = tokenize(promptCtx, asstSuffix, true); - decodePrompt(promptCallback, responseCallback, recalculateCallback, promptCtx, embd_inp); + decodePrompt(promptCallback, responseCallback, allowContextShift, promptCtx, embd_inp); } } // returns false on error bool LLModel::decodePrompt(std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx, std::vector embd_inp) { - // save the context size - promptCtx.n_ctx = contextLength(); - if ((int) embd_inp.size() > promptCtx.n_ctx - 4) { responseCallback(-1, "ERROR: The prompt size exceeds the context window size and cannot be processed."); std::cerr << implementation().modelType() << " ERROR: The prompt is " << embd_inp.size() << @@ -192,9 +165,14 @@ bool LLModel::decodePrompt(std::function promptCallback, return false; } - promptCtx.n_predict = std::min(promptCtx.n_predict, promptCtx.n_ctx - (int) embd_inp.size()); - promptCtx.n_past = std::min(promptCtx.n_past, promptCtx.n_ctx); - promptCtx.n_batch = std::min(promptCtx.n_batch, LLMODEL_MAX_PROMPT_BATCH); + // FIXME(jared): There are mitigations for this situation, such as making room before + // copying the prompt context, or restoring the KV cache when we restore the prompt + // context. + if (!allowContextShift && promptCtx.n_past + embd_inp.size() > promptCtx.n_ctx) { + std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_eval=" << embd_inp.size() + << ", n_ctx=" << promptCtx.n_ctx << "\n"; + return false; + } // process the prompt in batches size_t i = 0; @@ -204,7 +182,8 @@ bool LLModel::decodePrompt(std::function promptCallback, // Check if the context has run out... if (promptCtx.n_past + int32_t(batch.size()) > promptCtx.n_ctx) { - recalculateContext(promptCtx, recalculateCallback); + assert(allowContextShift); + shiftContext(promptCtx); assert(promptCtx.n_past + int32_t(batch.size()) <= promptCtx.n_ctx); } @@ -226,70 +205,170 @@ bool LLModel::decodePrompt(std::function promptCallback, return true; } +/* + * If string s overlaps with the string key such that some prefix of the key is at the end + * of the string, return the position in s where the first match starts. Otherwise, return + * std::string::npos. 
Examples: + * s = "bfo", key = "foo" -> 1 + * s = "fooa", key = "foo" -> npos + */ +static std::string::size_type stringsOverlap(const std::string &s, const std::string &key) +{ + if (s.empty() || key.empty()) + throw std::invalid_argument("arguments to stringsOverlap must not be empty"); + + for (int start = std::max(0, int(s.size()) - int(key.size())); start < s.size(); start++) { + if (s.compare(start, s.size(), key, 0, s.size() - start) == 0) + return start; + } + return std::string::npos; +} + void LLModel::generateResponse(std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx) { + static const char *stopSequences[] { + "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context", + }; + + // Don't even start if there is no room + if (!promptCtx.n_predict) + return; + if (!allowContextShift && promptCtx.n_past >= promptCtx.n_ctx) { + std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" << promptCtx.n_ctx + << "\n"; + return; + } + std::string cachedResponse; std::vector cachedTokens; - std::unordered_set reversePrompts - = { "### Instruction", "### Prompt", "### Response", "### Human", "### Assistant", "### Context" }; + int n_predicted = 0; - // predict next tokens - for (int i = 0; i < promptCtx.n_predict; i++) { + // Predict next tokens + for (bool stop = false; !stop;) { + // Sample next token + std::optional new_tok = sampleToken(promptCtx); + std::string new_piece = tokenToString(new_tok.value()); + cachedTokens.push_back(new_tok.value()); + cachedResponse += new_piece; - // sample next token - auto id = sampleToken(promptCtx); + auto accept = [this, &promptCtx, &cachedTokens, &new_tok, allowContextShift]() -> bool { + // Shift context if out of space + if (promptCtx.n_past >= promptCtx.n_ctx) { + (void)allowContextShift; + assert(allowContextShift); + shiftContext(promptCtx); + assert(promptCtx.n_past < promptCtx.n_ctx); + } - // Check if the context has run out... 
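stringsOverlap is what lets the rewritten generateResponse hold back text that might still turn into a stop sequence. An equivalent standalone restatement with a few spot checks (the loop bound is written differently here; the behaviour is intended to match the helper above):

    #include <cassert>
    #include <stdexcept>
    #include <string>

    static std::string::size_type stringsOverlap(const std::string &s, const std::string &key)
    {
        if (s.empty() || key.empty())
            throw std::invalid_argument("arguments to stringsOverlap must not be empty");

        // Only the last key.size() characters of s can start a prefix of key.
        for (size_t start = s.size() > key.size() ? s.size() - key.size() : 0; start < s.size(); start++) {
            if (s.compare(start, std::string::npos, key, 0, s.size() - start) == 0)
                return start;
        }
        return std::string::npos;
    }

    int main()
    {
        assert(stringsOverlap("bfo", "foo") == 1);                // "fo" is a prefix of "foo"
        assert(stringsOverlap("fooa", "foo") == std::string::npos);
        assert(stringsOverlap("text ### Hu", "### Human") == 5);  // partial stop sequence at the end
        assert(stringsOverlap("abc", "abc") == 0);                // a full match counts too
    }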
- if (promptCtx.n_past + 1 > promptCtx.n_ctx) { - recalculateContext(promptCtx, recalculateCallback); - assert(promptCtx.n_past + 1 <= promptCtx.n_ctx); - } + // Accept the token + Token tok = std::exchange(new_tok, std::nullopt).value(); + if (!evalTokens(promptCtx, { tok })) { + // TODO(jared): raise an exception + std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n"; + return false; + } - if (!evalTokens(promptCtx, { id })) { - std::cerr << implementation().modelType() << " ERROR: Failed to predict next token\n"; - return; - } + promptCtx.tokens.push_back(tok); + promptCtx.n_past += 1; + return true; + }; - // display text + // Check for EOS + auto lengthLimit = std::string::npos; for (const auto token : endTokens()) { - if (id == token) return; - } - - const std::string str = tokenToString(id); - - // Check if the provided str is part of our reverse prompts - bool foundPartialReversePrompt = false; - const std::string completed = cachedResponse + std::string(str); - if (reversePrompts.find(completed) != reversePrompts.end()) - return; - - // Check if it partially matches our reverse prompts and if so, cache - for (const auto& s : reversePrompts) { - if (s.compare(0, completed.size(), completed) == 0) { - foundPartialReversePrompt = true; - cachedResponse = completed; - break; + if (new_tok == token) { + stop = true; + lengthLimit = cachedResponse.size() - new_piece.size(); } } - // Regardless the token gets added to our cache - cachedTokens.push_back(id); + if (lengthLimit != std::string::npos) { + // EOS matched + } else if (!isSpecialToken(new_tok.value())) { + // Check if the response contains a stop sequence + for (const auto &p : stopSequences) { + auto match = cachedResponse.find(p); + if (match != std::string::npos) stop = true; + lengthLimit = std::min(lengthLimit, match); + if (match == 0) break; + } - // Continue if we have found a partial match - if (foundPartialReversePrompt) - continue; - - // Empty the cache - for (auto t : cachedTokens) { - promptCtx.tokens.push_back(t); - promptCtx.n_past += 1; - //TODO: Conversion to std::string can be avoided here... - if (!responseCallback(t, std::string(tokenToString(t)))) - return; + // Check if the response matches the start of a stop sequence + if (lengthLimit == std::string::npos) { + for (const auto &p : stopSequences) { + auto match = stringsOverlap(cachedResponse, p); + lengthLimit = std::min(lengthLimit, match); + if (match == 0) break; + } + } + } else if (ranges::contains(stopSequences, new_piece)) { + // Special tokens must exactly match a stop sequence + stop = true; + lengthLimit = cachedResponse.size() - new_piece.size(); + } + + // Optionally stop if the context will run out + if (!allowContextShift && promptCtx.n_past + cachedTokens.size() >= promptCtx.n_ctx) { + std::cerr << "LLModel Warning: Not enough space, n_past=" << promptCtx.n_past << ", n_ctx=" + << promptCtx.n_ctx << "\n"; + stop = true; + } + + // Empty the cache, up to the length limit + std::string::size_type responseLength = 0; + while (!cachedTokens.empty()) { + Token tok = cachedTokens.front(); + std::string piece = tokenToString(tok); + + // Stop if the piece (or part of it) does not fit within the length limit + if (responseLength + (stop ? 
1 : piece.size()) > lengthLimit) + break; + + // Remove token from cache + assert(cachedResponse.starts_with(piece)); + cachedTokens.erase(cachedTokens.begin(), cachedTokens.begin() + 1); + cachedResponse.erase(cachedResponse.begin(), cachedResponse.begin() + piece.size()); + + // Accept the token, if needed (not cached) + if (cachedTokens.empty() && new_tok && !accept()) + return; + + // Send the token + if (!responseCallback(tok, piece) || ++n_predicted >= promptCtx.n_predict) { + stop = true; + break; + } + + // FIXME(jared): we could avoid printing partial stop sequences if we didn't have to + // output token IDs and could cache a partial token for the next prompt call + responseLength += piece.size(); + } + assert(cachedTokens.empty() == cachedResponse.empty()); + + // Accept the token, if needed (in cache) + if (new_tok) { + assert(!cachedTokens.empty() && cachedTokens.back() == new_tok); + if (stop) { + cachedTokens.pop_back(); + } else if (!accept()) { + return; + } } - cachedTokens.clear(); } + + auto &tokens = promptCtx.tokens; + if (tokens.size() < cachedTokens.size()) { + /* This is theoretically possible if the longest stop sequence is greater than + * n_ctx * contextErase tokens. */ + throw std::runtime_error("shifted too much context, can't go back"); + } + + auto discard_start = tokens.end() - cachedTokens.size(); + assert(std::equal(discard_start, tokens.end(), cachedTokens.begin())); + tokens.erase(discard_start, tokens.end()); + + promptCtx.n_past -= cachedTokens.size(); } void LLModel::embed( diff --git a/gpt4all-bindings/python/gpt4all/_pyllmodel.py b/gpt4all-bindings/python/gpt4all/_pyllmodel.py index 4c52a5b1..a4952fe3 100644 --- a/gpt4all-bindings/python/gpt4all/_pyllmodel.py +++ b/gpt4all-bindings/python/gpt4all/_pyllmodel.py @@ -128,7 +128,6 @@ llmodel.llmodel_isModelLoaded.restype = ctypes.c_bool PromptCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32) ResponseCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_int32, ctypes.c_char_p) -RecalculateCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.c_bool) EmbCancelCallback = ctypes.CFUNCTYPE(ctypes.c_bool, ctypes.POINTER(ctypes.c_uint), ctypes.c_uint, ctypes.c_char_p) llmodel.llmodel_prompt.argtypes = [ @@ -137,7 +136,7 @@ llmodel.llmodel_prompt.argtypes = [ ctypes.c_char_p, PromptCallback, ResponseCallback, - RecalculateCallback, + ctypes.c_bool, ctypes.POINTER(LLModelPromptContext), ctypes.c_bool, ctypes.c_char_p, @@ -513,7 +512,7 @@ class LLModel: ctypes.c_char_p(prompt_template.encode()), PromptCallback(self._prompt_callback), ResponseCallback(self._callback_decoder(callback)), - RecalculateCallback(self._recalculate_callback), + True, self.context, special, ctypes.c_char_p(), @@ -606,8 +605,3 @@ class LLModel: @staticmethod def _prompt_callback(token_id: int) -> bool: return True - - # Empty recalculate callback - @staticmethod - def _recalculate_callback(is_recalculating: bool) -> bool: - return is_recalculating diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index e19b1c4e..07acef15 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.16) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD 23) set(CMAKE_CXX_STANDARD_REQUIRED ON) if(APPLE) @@ -31,7 +31,6 @@ project(gpt4all VERSION ${APP_VERSION_BASE} LANGUAGES CXX C) set(CMAKE_AUTOMOC ON) set(CMAKE_AUTORCC ON) -set(CMAKE_CXX_STANDARD_REQUIRED ON) option(GPT4ALL_TRANSLATIONS OFF "Build with translations") 
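Stepping back to the generateResponse rewrite that ends above: sampled tokens are buffered in cachedTokens/cachedResponse and only flushed up to a length limit, so a full stop-sequence match is never shown and a partial match at the end of the buffer is held back until it either completes or diverges. The sketch below distills that decision to plain strings; it ignores EOS, special tokens and n_predict, and every name in it is mine rather than the PR's:

    #include <algorithm>
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <vector>

    // Position in s where a prefix of key starts at the end of s, or npos (same idea as above).
    static std::string::size_type stringsOverlap(const std::string &s, const std::string &key)
    {
        for (size_t start = s.size() > key.size() ? s.size() - key.size() : 0; start < s.size(); start++) {
            if (s.compare(start, std::string::npos, key, 0, s.size() - start) == 0)
                return start;
        }
        return std::string::npos;
    }

    struct EmitDecision {
        std::string::size_type safeLength; // how many buffered chars may be shown
        bool stop;                         // a stop sequence fully matched
    };

    // Distilled version of the lengthLimit logic in the new generateResponse.
    static EmitDecision decideEmit(const std::string &buffered, const std::vector<std::string> &stops)
    {
        EmitDecision d{std::string::npos, false};
        for (const auto &s : stops) {
            auto match = buffered.find(s);
            if (match != std::string::npos)
                d.stop = true;
            d.safeLength = std::min(d.safeLength, match);
        }
        if (!d.stop) {
            // Nothing matched outright; hold back anything that could still become a stop sequence.
            for (const auto &s : stops)
                d.safeLength = std::min(d.safeLength, stringsOverlap(buffered, s));
        }
        if (d.safeLength == std::string::npos)
            d.safeLength = buffered.size(); // nothing needs to be held back
        return d;
    }

    int main()
    {
        std::vector<std::string> stops{"### Human", "### Assistant"};

        auto d1 = decideEmit("Sure, here you go", stops);
        assert(!d1.stop && d1.safeLength == 17);   // everything can be shown

        auto d2 = decideEmit("Sure.\n### Hu", stops);
        assert(!d2.stop && d2.safeLength == 6);    // hold back the possible "### Human"

        auto d3 = decideEmit("Sure.\n### Human:", stops);
        assert(d3.stop && d3.safeLength == 6);     // stop, and never emit the stop sequence

        std::cout << "ok\n";
    }

In the real code the same decision is made per token, which is why accepted-but-unsent tokens are finally erased from promptCtx.tokens and subtracted from n_past at the end of generateResponse.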
option(GPT4ALL_LOCALHOST OFF "Build installer for localhost repo") diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 2747ff9b..a44022c0 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -62,7 +62,7 @@ void Chat::connectLLM() connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection); + connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); @@ -252,9 +252,9 @@ void Chat::serverNewPromptResponsePair(const QString &prompt) m_chatModel->appendResponse("Response: ", prompt); } -bool Chat::isRecalc() const +bool Chat::restoringFromText() const { - return m_llmodel->isRecalc(); + return m_llmodel->restoringFromText(); } void Chat::unloadAndDeleteLater() @@ -320,10 +320,10 @@ void Chat::generatedQuestionFinished(const QString &question) emit generatedQuestionsChanged(); } -void Chat::handleRecalculating() +void Chat::handleRestoringFromText() { Network::globalInstance()->trackChatEvent("recalc_context", { {"length", m_chatModel->count()} }); - emit recalcChanged(); + emit restoringFromTextChanged(); } void Chat::handleModelLoadingError(const QString &error) diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h index c9b95f55..065c624e 100644 --- a/gpt4all-chat/chat.h +++ b/gpt4all-chat/chat.h @@ -27,7 +27,7 @@ class Chat : public QObject Q_PROPERTY(QString response READ response NOTIFY responseChanged) Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged) Q_PROPERTY(bool responseInProgress READ responseInProgress NOTIFY responseInProgressChanged) - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(bool isServer READ isServer NOTIFY isServerChanged) Q_PROPERTY(ResponseState responseState READ responseState NOTIFY responseStateChanged) Q_PROPERTY(QList collectionList READ collectionList NOTIFY collectionListChanged) @@ -88,7 +88,7 @@ public: ResponseState responseState() const; ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &modelInfo); - bool isRecalc() const; + bool restoringFromText() const; Q_INVOKABLE void unloadModel(); Q_INVOKABLE void reloadModel(); @@ -144,7 +144,7 @@ Q_SIGNALS: void processSystemPromptRequested(); void modelChangeRequested(const ModelInfo &modelInfo); void modelInfoChanged(); - void recalcChanged(); + void restoringFromTextChanged(); void loadDefaultModelRequested(); void loadModelRequested(const ModelInfo &modelInfo); void generateNameRequested(); @@ -167,7 +167,7 @@ private Q_SLOTS: void responseStopped(qint64 promptResponseMs); void generatedNameChanged(const QString &name); void generatedQuestionFinished(const QString &question); - void handleRecalculating(); + void handleRestoringFromText(); void handleModelLoadingError(const QString 
&error); void handleTokenSpeedChanged(const QString &tokenSpeed); void handleDatabaseResultsChanged(const QList &results); diff --git a/gpt4all-chat/chatapi.cpp b/gpt4all-chat/chatapi.cpp index 1cf94173..b443f24c 100644 --- a/gpt4all-chat/chatapi.cpp +++ b/gpt4all-chat/chatapi.cpp @@ -90,13 +90,13 @@ void ChatAPI::prompt(const std::string &prompt, const std::string &promptTemplate, std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &promptCtx, bool special, std::string *fakeReply) { Q_UNUSED(promptCallback); - Q_UNUSED(recalculateCallback); + Q_UNUSED(allowContextShift); Q_UNUSED(special); if (!isModelLoaded()) { diff --git a/gpt4all-chat/chatapi.h b/gpt4all-chat/chatapi.h index 51ba8067..59b68f58 100644 --- a/gpt4all-chat/chatapi.h +++ b/gpt4all-chat/chatapi.h @@ -69,7 +69,7 @@ public: const std::string &promptTemplate, std::function promptCallback, std::function responseCallback, - std::function recalculateCallback, + bool allowContextShift, PromptContext &ctx, bool special, std::string *fakeReply) override; @@ -97,38 +97,57 @@ protected: // them as they are only called from the default implementation of 'prompt' which we override and // completely replace - std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override { + std::vector tokenize(PromptContext &ctx, const std::string &str, bool special) override + { (void)ctx; (void)str; (void)special; throw std::logic_error("not implemented"); } - std::string tokenToString(Token id) const override { + bool isSpecialToken(Token id) const override + { (void)id; throw std::logic_error("not implemented"); } - Token sampleToken(PromptContext &ctx) const override { + std::string tokenToString(Token id) const override + { + (void)id; + throw std::logic_error("not implemented"); + } + + Token sampleToken(PromptContext &ctx) const override + { (void)ctx; throw std::logic_error("not implemented"); } - bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override { + bool evalTokens(PromptContext &ctx, const std::vector &tokens) const override + { (void)ctx; (void)tokens; throw std::logic_error("not implemented"); } - int32_t contextLength() const override { + void shiftContext(PromptContext &promptCtx) override + { + (void)promptCtx; throw std::logic_error("not implemented"); } - const std::vector &endTokens() const override { + int32_t contextLength() const override + { throw std::logic_error("not implemented"); } - bool shouldAddBOS() const override { + const std::vector &endTokens() const override + { + throw std::logic_error("not implemented"); + } + + bool shouldAddBOS() const override + { throw std::logic_error("not implemented"); } diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 18319cee..e9fb7f31 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -102,7 +102,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) : QObject{nullptr} , m_promptResponseTokens(0) , m_promptTokens(0) - , m_isRecalc(false) + , m_restoringFromText(false) , m_shouldBeLoaded(false) , m_forceUnloadModel(false) , m_markedForDeletion(false) @@ -712,17 +712,6 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response) return !m_stopGenerating; } -bool ChatLLM::handleRecalculate(bool isRecalc) -{ -#if defined(DEBUG) - qDebug() << "recalculate" << m_llmThread.objectName() << isRecalc; -#endif - if (m_isRecalc != isRecalc) { - m_isRecalc = isRecalc; - emit recalcChanged(); - } - return 
!m_stopGenerating; -} bool ChatLLM::prompt(const QList &collectionList, const QString &prompt) { if (m_restoreStateFromText) { @@ -776,7 +765,6 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, std::placeholders::_2); - auto recalcFunc = std::bind(&ChatLLM::handleRecalculate, this, std::placeholders::_1); emit promptProcessing(); m_ctx.n_predict = n_predict; m_ctx.top_k = top_k; @@ -796,10 +784,12 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString m_timer->start(); if (!docsContext.isEmpty()) { auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode localdocs context without a response - m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, recalcFunc, m_ctx); + m_llModelInfo.model->prompt(docsContext.toStdString(), "%1", promptFunc, responseFunc, + /*allowContextShift*/ true, m_ctx); m_ctx.n_predict = old_n_predict; // now we are ready for a response } - m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); + m_llModelInfo.model->prompt(prompt.toStdString(), promptTemplate.toStdString(), promptFunc, responseFunc, + /*allowContextShift*/ true, m_ctx); #if defined(DEBUG) printf("\n"); fflush(stdout); @@ -904,10 +894,9 @@ void ChatLLM::generateName() auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1); auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); - auto recalcFunc = std::bind(&ChatLLM::handleNameRecalculate, this, std::placeholders::_1); LLModel::PromptContext ctx = m_ctx; m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(), - promptFunc, responseFunc, recalcFunc, ctx); + promptFunc, responseFunc, /*allowContextShift*/ false, ctx); std::string trimmed = trim_whitespace(m_nameResponse); if (trimmed != m_nameResponse) { m_nameResponse = trimmed; @@ -944,15 +933,6 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) return words.size() <= 3; } -bool ChatLLM::handleNameRecalculate(bool isRecalc) -{ -#if defined(DEBUG) - qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc; -#endif - Q_UNUSED(isRecalc); - return true; -} - bool ChatLLM::handleQuestionPrompt(int32_t token) { #if defined(DEBUG) @@ -991,15 +971,6 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response) return true; } -bool ChatLLM::handleQuestionRecalculate(bool isRecalc) -{ -#if defined(DEBUG) - qDebug() << "name recalc" << m_llmThread.objectName() << isRecalc; -#endif - Q_UNUSED(isRecalc); - return true; -} - void ChatLLM::generateQuestions(qint64 elapsed) { Q_ASSERT(isModelLoaded()); @@ -1019,12 +990,11 @@ void ChatLLM::generateQuestions(qint64 elapsed) auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1); auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); - auto recalcFunc = std::bind(&ChatLLM::handleQuestionRecalculate, this, std::placeholders::_1); LLModel::PromptContext ctx = m_ctx; QElapsedTimer totalTime; totalTime.start(); - 
m_llModelInfo.model->prompt(suggestedFollowUpPrompt, - promptTemplate.toStdString(), promptFunc, responseFunc, recalcFunc, ctx); + m_llModelInfo.model->prompt(suggestedFollowUpPrompt, promptTemplate.toStdString(), promptFunc, responseFunc, + /*allowContextShift*/ false, ctx); elapsed += totalTime.elapsed(); emit responseStopped(elapsed); } @@ -1039,15 +1009,6 @@ bool ChatLLM::handleSystemPrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleSystemRecalculate(bool isRecalc) -{ -#if defined(DEBUG) - qDebug() << "system recalc" << m_llmThread.objectName() << isRecalc; -#endif - Q_UNUSED(isRecalc); - return false; -} - bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) { #if defined(DEBUG) @@ -1057,15 +1018,6 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) return !m_stopGenerating; } -bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc) -{ -#if defined(DEBUG) - qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc; -#endif - Q_UNUSED(isRecalc); - return false; -} - // this function serialized the cached model state to disk. // we want to also serialize n_ctx, and read it at load time. bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) @@ -1268,7 +1220,6 @@ void ChatLLM::processSystemPrompt() m_ctx = LLModel::PromptContext(); auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1); - auto recalcFunc = std::bind(&ChatLLM::handleSystemRecalculate, this, std::placeholders::_1); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); @@ -1294,7 +1245,7 @@ void ChatLLM::processSystemPrompt() #endif auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response // use "%1%2" and not "%1" to avoid implicit whitespace - m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, recalcFunc, m_ctx, true); + m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true); m_ctx.n_predict = old_n_predict; #if defined(DEBUG) printf("\n"); @@ -1311,14 +1262,13 @@ void ChatLLM::processRestoreStateFromText() if (!isModelLoaded() || !m_restoreStateFromText || m_isServer) return; - m_isRecalc = true; - emit recalcChanged(); + m_restoringFromText = true; + emit restoringFromTextChanged(); m_stopGenerating = false; m_ctx = LLModel::PromptContext(); auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); - auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1); const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); @@ -1351,7 +1301,7 @@ void ChatLLM::processRestoreStateFromText() auto responseText = response.second.toStdString(); m_llModelInfo.model->prompt(prompt.second.toStdString(), promptTemplate.toStdString(), promptFunc, nullptr, - recalcFunc, m_ctx, false, &responseText); + /*allowContextShift*/ true, m_ctx, false, &responseText); } if (!m_stopGenerating) { @@ -1359,8 +1309,8 @@ void ChatLLM::processRestoreStateFromText() m_stateFromText.clear(); } - m_isRecalc = false; - emit recalcChanged(); + m_restoringFromText = false; + emit restoringFromTextChanged(); m_pristineLoadedState = false; } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h 
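In chatllm.cpp above, the conversation prompt passes allowContextShift = true, while generateName and generateQuestions pass false: they run against a copied PromptContext, and shifting there would presumably disturb the KV cache that the ongoing chat still needs. When shifting is disallowed and the window is full, decodePrompt now warns and returns false instead of recalculating; a small stand-alone model of that guard (roomToDecode is a made-up name):

    #include <cstdio>

    // Mirrors the early-out added to decodePrompt: without permission to shift, a prompt
    // that does not fit in the remaining window is rejected instead of evicting context.
    bool roomToDecode(int n_past, int n_eval, int n_ctx, bool allowContextShift)
    {
        if (!allowContextShift && n_past + n_eval > n_ctx) {
            std::fprintf(stderr, "LLModel Warning: Not enough space, n_past=%d, n_eval=%d, n_ctx=%d\n",
                         n_past, n_eval, n_ctx);
            return false;
        }
        return true;
    }

    int main()
    {
        // e.g. a follow-up-question prompt of 300 tokens near the end of a 2048-token window
        bool ok = roomToDecode(/*n_past*/ 1900, /*n_eval*/ 300, /*n_ctx*/ 2048, /*allowContextShift*/ false);
        std::printf("decode allowed: %s\n", ok ? "yes" : "no"); // prints "no"
    }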
index 252b77bd..d123358a 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -93,7 +93,7 @@ class Chat; class ChatLLM : public QObject { Q_OBJECT - Q_PROPERTY(bool isRecalc READ isRecalc NOTIFY recalcChanged) + Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) @@ -121,7 +121,7 @@ public: ModelInfo modelInfo() const; void setModelInfo(const ModelInfo &info); - bool isRecalc() const { return m_isRecalc; } + bool restoringFromText() const { return m_restoringFromText; } void acquireModel(); void resetModel(); @@ -172,7 +172,7 @@ public Q_SLOTS: void processRestoreStateFromText(); Q_SIGNALS: - void recalcChanged(); + void restoringFromTextChanged(); void loadedModelInfoChanged(); void modelLoadingPercentageChanged(float); void modelLoadingError(const QString &error); @@ -201,19 +201,14 @@ protected: int32_t repeat_penalty_tokens); bool handlePrompt(int32_t token); bool handleResponse(int32_t token, const std::string &response); - bool handleRecalculate(bool isRecalc); bool handleNamePrompt(int32_t token); bool handleNameResponse(int32_t token, const std::string &response); - bool handleNameRecalculate(bool isRecalc); bool handleSystemPrompt(int32_t token); bool handleSystemResponse(int32_t token, const std::string &response); - bool handleSystemRecalculate(bool isRecalc); bool handleRestoreStateFromTextPrompt(int32_t token); bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response); - bool handleRestoreStateFromTextRecalculate(bool isRecalc); bool handleQuestionPrompt(int32_t token); bool handleQuestionResponse(int32_t token, const std::string &response); - bool handleQuestionRecalculate(bool isRecalc); void saveState(); void restoreState(); @@ -236,7 +231,7 @@ private: QThread m_llmThread; std::atomic m_stopGenerating; std::atomic m_shouldBeLoaded; - std::atomic m_isRecalc; + std::atomic m_restoringFromText; // status indication std::atomic m_forceUnloadModel; std::atomic m_markedForDeletion; bool m_isServer; diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index 2be8d75e..920e2759 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -834,7 +834,7 @@ Rectangle { to: 360 duration: 1000 loops: Animation.Infinite - running: currentResponse && (currentChat.responseInProgress || currentChat.isRecalc) + running: currentResponse && (currentChat.responseInProgress || currentChat.restoringFromText) } } } @@ -867,13 +867,13 @@ Rectangle { color: theme.mutedTextColor } RowLayout { - visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.isRecalc) + visible: currentResponse && ((value === "" && currentChat.responseInProgress) || currentChat.restoringFromText) Text { color: theme.mutedTextColor font.pixelSize: theme.fontSizeLarger text: { - if (currentChat.isRecalc) - return qsTr("recalculating context ..."); + if (currentChat.restoringFromText) + return qsTr("restoring from text ..."); switch (currentChat.responseState) { case Chat.ResponseStopped: return qsTr("response stopped ..."); case Chat.LocalDocsRetrieval: return qsTr("retrieving localdocs: %1 ...").arg(currentChat.collectionList.join(", ")); @@ -1861,7 +1861,7 @@ Rectangle { } } function sendMessage() { - if (textInput.text === "" || 
currentChat.responseInProgress || currentChat.isRecalc) + if (textInput.text === "" || currentChat.responseInProgress || currentChat.restoringFromText) return currentChat.stopGenerating()
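The QML above binds to currentChat.restoringFromText, which only works because chat.h and chatllm.h rename the READ accessor and the NOTIFY signal together with the property. A minimal stand-in for the pattern, assuming a Qt/moc build (ChatStub is hypothetical; the real property lives on Chat and mirrors ChatLLM):

    #include <QObject>

    class ChatStub : public QObject
    {
        Q_OBJECT
        Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)

    public:
        bool restoringFromText() const { return m_restoringFromText; }

        void setRestoringFromText(bool value)
        {
            if (m_restoringFromText != value) {
                m_restoringFromText = value;
                // QML bindings such as currentChat.restoringFromText re-evaluate on this signal.
                emit restoringFromTextChanged();
            }
        }

    Q_SIGNALS:
        void restoringFromTextChanged();

    private:
        bool m_restoringFromText = false;
    };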