From f67b370f5a57b8215041024a98292f372c0e15a2 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Fri, 13 Dec 2024 12:19:47 -0500
Subject: [PATCH] Fix local server regressions caused by Jinja PR (#3256)

Signed-off-by: Jared Van Bortel
---
 gpt4all-chat/CHANGELOG.md                    |  8 +++
 gpt4all-chat/src/chatllm.cpp                 | 57 ++++++++-------
 gpt4all-chat/src/chatllm.h                   |  3 +-
 gpt4all-chat/src/chatmodel.h                 | 74 +++++++++++---------
 gpt4all-chat/src/server.cpp                  | 10 +--
 gpt4all-chat/tests/python/test_server_api.py |  2 +-
 6 files changed, 88 insertions(+), 66 deletions(-)

diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md
index cc880c67..3a455d10 100644
--- a/gpt4all-chat/CHANGELOG.md
+++ b/gpt4all-chat/CHANGELOG.md
@@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 
+## [Unreleased]
+
+### Fixed
+- Fix API server ignoring assistant messages in history after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+- Fix API server replying with incorrect token counts and stop reason after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+- Fix API server remembering previous, unrelated conversations after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+
 ## [3.5.1] - 2024-12-10
 
 ### Fixed
@@ -211,6 +218,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
 - Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
 
+[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.5.1...HEAD
 [3.5.1]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0...v3.5.1
 [3.5.0]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0-rc2...v3.5.0
 [3.5.0-rc2]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0-rc1...v3.5.0-rc2
diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp
index f575ac2d..8d2f11ad 100644
--- a/gpt4all-chat/src/chatllm.cpp
+++ b/gpt4all-chat/src/chatllm.cpp
@@ -851,39 +851,44 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const ChatItem> items) const
     return *maybeRendered;
 }
 
-auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx)
-    -> ChatPromptResult
+auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
+                                 std::optional<std::span<const ChatItem>> chat) -> ChatPromptResult
 {
     Q_ASSERT(isModelLoaded());
     Q_ASSERT(m_chatModel);
 
-    QList<ResultInfo> databaseResults;
-    const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
-    if (!enabledCollections.isEmpty()) {
-        std::optional<std::pair<int, QString>> query;
-        {
-            // Find the prompt that represents the query. Server chats are flexible and may not have one.
-            auto items = m_chatModel->chatItems(); // holds lock
-            Q_ASSERT(items);
-            auto response = items.end() - 1;
-            if (auto peer = m_chatModel->getPeerUnlocked(response))
-                query = {*peer - items.begin(), (*peer)->value};
-        }
-        if (query) {
-            auto &[promptIndex, queryStr] = *query;
-            emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
-            m_chatModel->updateSources(promptIndex, databaseResults);
-            emit databaseResultsChanged(databaseResults);
-        }
-    }
-    // copy messages for safety (since we can't hold the lock the whole time)
+    std::optional<std::pair<int, QString>> query;
     std::vector<ChatItem> chatItems;
     {
-        auto items = m_chatModel->chatItems(); // holds lock
-        Q_ASSERT(items.size() >= 2); // should be prompt/response pairs
-        chatItems.assign(items.begin(), items.end() - 1); // exclude last
+        std::optional<ChatModelAccessor> items;
+        std::span<const ChatItem> view;
+        if (chat) {
+            view = *chat;
+        } else {
+            items = m_chatModel->chatItems(); // holds lock
+            Q_ASSERT(!items->empty());
+            view = *items;
+        }
+        Q_ASSERT(view.size() >= 2); // should be prompt/response pairs
+
+        // Find the prompt that represents the query. Server chats are flexible and may not have one.
+        auto response = view.end() - 1;
+        if (auto peer = m_chatModel->getPeer(view, response))
+            query = { *peer - view.begin(), (*peer)->value };
+
+        chatItems.assign(view.begin(), view.end() - 1); // exclude last
     }
+
+    QList<ResultInfo> databaseResults;
+    if (query && !enabledCollections.isEmpty()) {
+        auto &[promptIndex, queryStr] = *query;
+        const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
+        emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
+        m_chatModel->updateSources(promptIndex, databaseResults);
+        emit databaseResultsChanged(databaseResults);
+    }
+
     auto result = promptInternal(chatItems, ctx, !databaseResults.isEmpty());
 
     return {
         /*PromptResult*/ {
@@ -933,7 +938,7 @@ auto ChatLLM::promptInternal(
         }
     }
 
-    PromptResult result;
+    PromptResult result {};
 
     auto handlePrompt = [this, &result](std::span<const LLModel::Token> batch, bool cached) -> bool {
         Q_UNUSED(cached)
diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h
index c79ca0bd..57905286 100644
--- a/gpt4all-chat/src/chatllm.h
+++ b/gpt4all-chat/src/chatllm.h
@@ -250,7 +250,8 @@ protected:
         QList<ResultInfo> databaseResults;
     };
 
-    ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx);
+    ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
+                                        std::optional<std::span<const ChatItem>> chat = {});
     // passing a string_view directly skips templating and uses the raw string
     PromptResult promptInternal(const std::variant<std::span<const ChatItem>, std::string_view> &prompt,
                                 const LLModel::PromptContext &ctx,
diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h
index 1ac5843c..036150dc 100644
--- a/gpt4all-chat/src/chatmodel.h
+++ b/gpt4all-chat/src/chatmodel.h
@@ -111,8 +111,13 @@ public:
     ChatItem(prompt_tag_t, const QString &value, const QList<PromptAttachment> &attachments = {})
         : name(u"Prompt: "_s), value(value), promptAttachments(attachments) {}
 
-    ChatItem(response_tag_t, bool isCurrentResponse = true)
-        : name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {}
+    // A new response, to be filled in
+    ChatItem(response_tag_t)
+        : name(u"Response: "_s), isCurrentResponse(true) {}
+
+    // An existing response, from Server
+    ChatItem(response_tag_t, const QString &value)
+        : name(u"Response: "_s), value(value) {}
 
     Type type() const
     {
@@ -165,9 +170,9 @@ public:
 };
 Q_DECLARE_METATYPE(ChatItem)
 
-class ChatModelAccessor : public ranges::subrange<QList<ChatItem>::const_iterator> {
+class ChatModelAccessor : public std::span<const ChatItem> {
 private:
-    using Super = ranges::subrange<QList<ChatItem>::const_iterator>;
+    using Super = std::span<const ChatItem>;
 
 public:
     template <typename... T>
@@ -219,38 +224,38 @@ public:
 
     /* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs
      * sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */
-    auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const
-        -> std::optional<QList<ChatItem>::const_iterator>
+    static std::optional<qsizetype> getPeer(const ChatItem *arr, qsizetype size, qsizetype index)
     {
-        switch (item->type()) {
+        Q_ASSERT(index >= 0);
+        Q_ASSERT(index < size);
+        qsizetype peer;
+        ChatItem::Type expected;
+        switch (arr[index].type()) {
         using enum ChatItem::Type;
-        case Prompt:
-            {
-                auto peer = std::next(item);
-                if (peer < m_chatItems.cend() && peer->type() == Response)
-                    return peer;
-                break;
-            }
-        case Response:
-            {
-                if (item > m_chatItems.cbegin()) {
-                    if (auto peer = std::prev(item); peer->type() == Prompt)
-                        return peer;
-                }
-                break;
-            }
-        default:
-            throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
+        case Prompt:   peer = index + 1; expected = Response; break;
+        case Response: peer = index - 1; expected = Prompt;   break;
+        default: throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
         }
+        if (peer >= 0 && peer < size && arr[peer].type() == expected)
+            return peer;
         return std::nullopt;
     }
 
-    auto getPeerUnlocked(int index) const -> std::optional<int>
+    template <ranges::contiguous_range R>
+        requires std::same_as<ranges::range_value_t<R>, ChatItem>
+    static auto getPeer(R &&range, ranges::iterator_t<R> item) -> std::optional<ranges::iterator_t<R>>
     {
-        return getPeerUnlocked(m_chatItems.cbegin() + index)
-            .transform([&](auto &&i) { return i - m_chatItems.cbegin(); } );
+        auto begin = ranges::begin(range);
+        return getPeer(ranges::data(range), ranges::size(range), item - begin)
+            .transform([&](auto i) { return begin + i; });
     }
 
+    auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const -> std::optional<QList<ChatItem>::const_iterator>
+    { return getPeer(m_chatItems, item); }
+
+    std::optional<qsizetype> getPeerUnlocked(qsizetype index) const
+    { return getPeer(m_chatItems.constData(), m_chatItems.size(), index); }
+
     QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
     {
         QMutexLocker locker(&m_mutex);
@@ -356,26 +361,29 @@ public:
     }
 
     // Used by Server to append a new conversation to the chat log.
-    void appendResponseWithHistory(std::span<const ChatItem> history)
+    // Appends a new, blank response to the end of the input list.
+    void appendResponseWithHistory(QList<ChatItem> &history)
     {
         if (history.empty())
             throw std::invalid_argument("at least one message is required");
 
+        // add an empty response to prepare for generation
+        history.emplace_back(ChatItem::response_tag);
+
         m_mutex.lock();
-        qsizetype startIndex = m_chatItems.count();
+        qsizetype startIndex = m_chatItems.size();
         m_mutex.unlock();
-        qsizetype nNewItems = history.size() + 1;
-        qsizetype endIndex = startIndex + nNewItems;
+        qsizetype endIndex = startIndex + history.size();
         beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
         bool hadError;
+        QList<ChatItem> newItems;
         {
             QMutexLocker locker(&m_mutex);
             hadError = hasErrorUnlocked();
-            m_chatItems.reserve(m_chatItems.count() + nNewItems);
+            m_chatItems.reserve(m_chatItems.size() + history.size());
             for (auto &item : history)
                 m_chatItems << item;
-            m_chatItems.emplace_back(ChatItem::response_tag);
         }
         endInsertRows();
         emit countChanged();
diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp
index 4c48dcbd..91e17de7 100644
--- a/gpt4all-chat/src/server.cpp
+++ b/gpt4all-chat/src/server.cpp
@@ -771,13 +771,13 @@ auto Server::handleChatRequest(const ChatRequest &request)
     Q_ASSERT(!request.messages.isEmpty());
 
     // adds prompt/response items to GUI
-    std::vector<ChatItem> chatItems;
+    QList<ChatItem> chatItems;
     for (auto &message : request.messages) {
         using enum ChatRequest::Message::Role;
         switch (message.role) {
-            case System:    chatItems.emplace_back(ChatItem::system_tag,   message.content); break;
-            case User:      chatItems.emplace_back(ChatItem::prompt_tag,   message.content); break;
-            case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break;
+            case System:    chatItems.emplace_back(ChatItem::system_tag,   message.content); break;
+            case User:      chatItems.emplace_back(ChatItem::prompt_tag,   message.content); break;
+            case Assistant: chatItems.emplace_back(ChatItem::response_tag, message.content); break;
         }
     }
     m_chatModel->appendResponseWithHistory(chatItems);
@@ -800,7 +800,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     for (int i = 0; i < request.n; ++i) {
         ChatPromptResult result;
         try {
-            result = promptInternalChat(m_collections, promptCtx);
+            result = promptInternalChat(m_collections, promptCtx, chatItems);
         } catch (const std::exception &e) {
            emit responseChanged(e.what());
            emit responseStopped(0);
diff --git a/gpt4all-chat/tests/python/test_server_api.py b/gpt4all-chat/tests/python/test_server_api.py
index 26a6aff7..449d8e25 100644
--- a/gpt4all-chat/tests/python/test_server_api.py
+++ b/gpt4all-chat/tests/python/test_server_api.py
@@ -252,8 +252,8 @@ def test_with_models(chat_server_with_model: None) -> None:
     assert response == EXPECTED_COMPLETIONS_RESPONSE
 
 
-@pytest.mark.xfail(reason='Assertion failure in GPT4All. See nomic-ai/gpt4all#3133')
 def test_with_models_temperature(chat_server_with_model: None) -> None:
+    """Fixed by nomic-ai/gpt4all#3202."""
     data = {
         'model': 'Llama 3.2 1B Instruct',
         'prompt': 'The quick brown fox',
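
Note: the following is an illustrative sketch, not part of the patch. It exercises the scenario these fixes cover (an assistant message supplied in the request history) and reads back the token usage and finish reason that the server now reports. It assumes the local API server is enabled on its default port (4891) and that the model named below is installed; adjust both to match the local setup.

import requests

# Assumed values: the default local server port and an installed model name.
API_URL = "http://localhost:4891/v1/chat/completions"

payload = {
    "model": "Llama 3.2 1B Instruct",
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        # An assistant turn in the history; previously ignored by the server.
        {"role": "assistant", "content": "The capital of France is Paris."},
        {"role": "user", "content": "How large is that city?"},
    ],
    "max_tokens": 64,
    "temperature": 0,
}

response = requests.post(API_URL, json=payload)
response.raise_for_status()
body = response.json()

print(body["choices"][0]["message"]["content"])
print(body["usage"])                        # prompt/completion token counts
print(body["choices"][0]["finish_reason"])  # stop reason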