diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md
index 2e0ab50a..82ded06a 100644
--- a/gpt4all-chat/CHANGELOG.md
+++ b/gpt4all-chat/CHANGELOG.md
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))
 - Fix the local server rejecting min\_p/top\_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
 - Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
+- Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
 
 ## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))
 
diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp
index aded57b0..abbdf322 100644
--- a/gpt4all-chat/src/chatllm.cpp
+++ b/gpt4all-chat/src/chatllm.cpp
@@ -1230,51 +1230,49 @@ void ChatLLM::restoreState()
 void ChatLLM::processSystemPrompt()
 {
     Q_ASSERT(isModelLoaded());
-    if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText)
+    if (!isModelLoaded() || m_processedSystemPrompt)
         return;
 
     const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
-    if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
-        m_processedSystemPrompt = true;
-        return;
-    }
 
     // Start with a whole new context
     m_stopGenerating = false;
     m_ctx = LLModel::PromptContext();
 
-    auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
+    if (!QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
+        auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
 
-    const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
-    const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
-    const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
-    const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
-    const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
-    const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
-    const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
-    const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
-    int n_threads = MySettings::globalInstance()->threadCount();
-    m_ctx.n_predict = n_predict;
-    m_ctx.top_k = top_k;
-    m_ctx.top_p = top_p;
-    m_ctx.min_p = min_p;
-    m_ctx.temp = temp;
-    m_ctx.n_batch = n_batch;
-    m_ctx.repeat_penalty = repeat_penalty;
-    m_ctx.repeat_last_n = repeat_penalty_tokens;
-    m_llModelInfo.model->setThreadCount(n_threads);
+        const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
+        const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
+        const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
+        const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
+        const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
+        const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
+        const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
+        const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
+        int n_threads = MySettings::globalInstance()->threadCount();
+        m_ctx.n_predict = n_predict;
+        m_ctx.top_k = top_k;
+        m_ctx.top_p = top_p;
+        m_ctx.min_p = min_p;
+        m_ctx.temp = temp;
+        m_ctx.n_batch = n_batch;
+        m_ctx.repeat_penalty = repeat_penalty;
+        m_ctx.repeat_last_n = repeat_penalty_tokens;
+        m_llModelInfo.model->setThreadCount(n_threads);
 #if defined(DEBUG)
-    printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
-    fflush(stdout);
+        printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
+        fflush(stdout);
 #endif
-    auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
-    // use "%1%2" and not "%1" to avoid implicit whitespace
-    m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
-    m_ctx.n_predict = old_n_predict;
+        auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
+        // use "%1%2" and not "%1" to avoid implicit whitespace
+        m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
+        m_ctx.n_predict = old_n_predict;
 #if defined(DEBUG)
-    printf("\n");
-    fflush(stdout);
+        printf("\n");
+        fflush(stdout);
 #endif
+    }
 
     m_processedSystemPrompt = m_stopGenerating == false;
     m_pristineLoadedState = false;
@@ -1286,11 +1284,12 @@ void ChatLLM::processRestoreStateFromText()
     if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
         return;
 
+    processSystemPrompt();
+
     m_restoringFromText = true;
     emit restoringFromTextChanged();
 
     m_stopGenerating = false;
-    m_ctx = LLModel::PromptContext();
 
     auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
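
For orientation, here is a condensed sketch of the two functions as they read after this patch, using only the member and settings names visible in the diff above. The sampling-parameter setup, the DEBUG output, and the remainder of `processRestoreStateFromText()` are elided with comments; this is not the full implementation. The key behavioral change is that `processSystemPrompt()` now always resets `m_ctx` (even for an empty system prompt, and even when restoring a chat from text), and `processRestoreStateFromText()` reuses that freshly prepared context instead of overwriting it with a second `LLModel::PromptContext()`.

```cpp
// Condensed post-patch control flow; details elided, see the diff above.
void ChatLLM::processSystemPrompt()
{
    if (!isModelLoaded() || m_processedSystemPrompt)
        return;

    const std::string systemPrompt =
        MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();

    // Always start from a fresh context, even when the system prompt is empty,
    // so callers can rely on m_ctx being reset here.
    m_stopGenerating = false;
    m_ctx = LLModel::PromptContext();

    if (!QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
        // ... configure m_ctx from the model settings and decode the system prompt ...
    }

    m_processedSystemPrompt = m_stopGenerating == false;
    m_pristineLoadedState = false;
}

void ChatLLM::processRestoreStateFromText()
{
    if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
        return;

    // Decode the system prompt first; the context it prepares is kept,
    // rather than being discarded by a second LLModel::PromptContext() reset.
    processSystemPrompt();

    m_restoringFromText = true;
    emit restoringFromTextChanged();

    m_stopGenerating = false;
    // ... rest of the function unchanged apart from the removed m_ctx reset ...
}
```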