Fix loaded chats forgetting context with non-empty system prompt (#3015)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-10-01 11:25:04 -04:00 committed by GitHub
parent 3025f9deff
commit 88b95950c5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 33 additions and 33 deletions

View File

@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))
- Fix the local server rejecting min\_p/top\_p less than 1 ([#2996](https://github.com/nomic-ai/gpt4all/pull/2996))
- Fix "regenerate" always forgetting the most recent message ([#3011](https://github.com/nomic-ai/gpt4all/pull/3011))
- Fix loaded chats forgetting context when there is a system prompt ([#3015](https://github.com/nomic-ai/gpt4all/pull/3015))
## [3.3.1] - 2024-09-27 ([v3.3.y](https://github.com/nomic-ai/gpt4all/tree/v3.3.y))

View File

@@ -1230,51 +1230,49 @@ void ChatLLM::restoreState()
void ChatLLM::processSystemPrompt()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText)
if (!isModelLoaded() || m_processedSystemPrompt)
return;
const std::string systemPrompt = MySettings::globalInstance()->modelSystemPrompt(m_modelInfo).toStdString();
if (QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
m_processedSystemPrompt = true;
return;
}
// Start with a whole new context
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
if (!QString::fromStdString(systemPrompt).trimmed().isEmpty()) {
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float min_p = MySettings::globalInstance()->modelMinP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.min_p = min_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
#if defined(DEBUG)
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
printf("%s", qPrintable(QString::fromStdString(systemPrompt)));
fflush(stdout);
#endif
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
m_ctx.n_predict = old_n_predict;
auto old_n_predict = std::exchange(m_ctx.n_predict, 0); // decode system prompt without a response
// use "%1%2" and not "%1" to avoid implicit whitespace
m_llModelInfo.model->prompt(systemPrompt, "%1%2", promptFunc, nullptr, /*allowContextShift*/ true, m_ctx, true);
m_ctx.n_predict = old_n_predict;
#if defined(DEBUG)
printf("\n");
fflush(stdout);
printf("\n");
fflush(stdout);
#endif
}
m_processedSystemPrompt = m_stopGenerating == false;
m_pristineLoadedState = false;
@@ -1286,11 +1284,12 @@ void ChatLLM::processRestoreStateFromText()
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
processSystemPrompt();
m_restoringFromText = true;
emit restoringFromTextChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);