From dda59a97a64e5290171ebad0275d68c1518c4478 Mon Sep 17 00:00:00 2001
From: Adam Treat <treat.adam@gmail.com>
Date: Sat, 27 Jul 2024 11:11:41 -0400
Subject: [PATCH] Fix the way we're injecting the context back into the model
 for web search.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
---
 gpt4all-chat/chat.cpp    | 2 ++
 gpt4all-chat/chatllm.cpp | 6 +++---
 gpt4all-chat/chatmodel.h | 9 +++++++--
 3 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 7da8274c..2730eaf3 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -116,6 +116,7 @@ void Chat::resetResponseState()
     if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
         return;
 
+    m_sourceExcerpts = QList<SourceExcerpt>();
     m_generatedQuestions = QList<QString>();
     emit generatedQuestionsChanged();
     m_tokenSpeed = QString();
@@ -136,6 +137,7 @@ void Chat::prompt(const QString &prompt)
 void Chat::regenerateResponse()
 {
     const int index = m_chatModel->count() - 1;
+    m_sourceExcerpts = QList<SourceExcerpt>();
     m_chatModel->updateSources(index, QList<SourceExcerpt>());
     emit regenerateResponseRequested();
 }
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index a6dea395..073a1d7a 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -869,13 +869,13 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
     static QRegularExpression re(R"(brave_search\.call\(query=\"([^\"]+)\"\))");
     QRegularExpressionMatch match = re.match(toolCall);
 
-    QString prompt("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
+    QString promptTemplate("<|start_header_id|>ipython<|end_header_id|>\n\n%1<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n%2");
     QString query;
     if (match.hasMatch()) {
         query = match.captured(1);
     } else {
         qWarning() << "WARNING: Could not find the tool for " << toolCall;
-        return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(QString()), QString("%1") /*promptTemplate*/,
+        return promptInternal(QList<QString>()/*collectionList*/, QString() /*prompt*/, promptTemplate,
             n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
     }
 
@@ -887,7 +887,7 @@ bool ChatLLM::toolCallInternal(const QString &toolCall, int32_t n_predict, int32
 
     emit sourceExcerptsChanged(braveResponse.second);
 
-    return promptInternal(QList<QString>()/*collectionList*/, prompt.arg(braveResponse.first), QString("%1") /*promptTemplate*/,
+    return promptInternal(QList<QString>()/*collectionList*/, braveResponse.first, promptTemplate,
         n_predict, top_k, top_p, min_p, temp, n_batch, repeat_penalty, repeat_penalty_tokens);
 }
diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h
index 1031ded5..97b81275 100644
--- a/gpt4all-chat/chatmodel.h
+++ b/gpt4all-chat/chatmodel.h
@@ -219,8 +219,13 @@ public:
         if (index < 0 || index >= m_chatItems.size())
             return;
         ChatItem &item = m_chatItems[index];
-        item.sources = sources;
-        item.consolidatedSources = consolidateSources(sources);
+        if (sources.isEmpty()) {
+            item.sources.clear();
+            item.consolidatedSources.clear();
+        } else {
+            item.sources << sources;
+            item.consolidatedSources << consolidateSources(sources);
+        }
         emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
         emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
     }