Fix local server regressions caused by Jinja PR (#3256)

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-12-13 12:19:47 -05:00 committed by GitHub
parent 2c5097c9de
commit f67b370f5a
6 changed files with 88 additions and 66 deletions

View File

@@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 
+## [Unreleased]
+
+### Fixed
+- Fix API server ignoring assistant messages in history after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+- Fix API server replying with incorrect token counts and stop reason after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+- Fix API server remembering previous, unrelated conversations after v3.5.0 ([#3256](https://github.com/nomic-ai/gpt4all/pull/3256))
+
 ## [3.5.1] - 2024-12-10
 
 ### Fixed
@@ -211,6 +218,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
 - Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
 - Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))
 
+[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.5.1...HEAD
 [3.5.1]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0...v3.5.1
 [3.5.0]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0-rc2...v3.5.0
 [3.5.0-rc2]: https://github.com/nomic-ai/gpt4all/compare/v3.5.0-rc1...v3.5.0-rc2

View File

@@ -851,39 +851,44 @@ std::string ChatLLM::applyJinjaTemplate(std::span<const ChatItem> items) const
     return *maybeRendered;
 }
 
-auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx)
-    -> ChatPromptResult
+auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
+                                 std::optional<QList<ChatItem>> chat) -> ChatPromptResult
 {
     Q_ASSERT(isModelLoaded());
     Q_ASSERT(m_chatModel);
 
-    QList<ResultInfo> databaseResults;
-    const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
-    if (!enabledCollections.isEmpty()) {
-        std::optional<std::pair<int, QString>> query;
-        {
-            // Find the prompt that represents the query. Server chats are flexible and may not have one.
-            auto items = m_chatModel->chatItems(); // holds lock
-            Q_ASSERT(items);
-            auto response = items.end() - 1;
-            if (auto peer = m_chatModel->getPeerUnlocked(response))
-                query = {*peer - items.begin(), (*peer)->value};
-        }
-
-        if (query) {
-            auto &[promptIndex, queryStr] = *query;
-            emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
-            m_chatModel->updateSources(promptIndex, databaseResults);
-            emit databaseResultsChanged(databaseResults);
-        }
-    }
-
     // copy messages for safety (since we can't hold the lock the whole time)
+    std::optional<std::pair<int, QString>> query;
     std::vector<ChatItem> chatItems;
     {
-        auto items = m_chatModel->chatItems(); // holds lock
-        Q_ASSERT(items.size() >= 2); // should be prompt/response pairs
-        chatItems.assign(items.begin(), items.end() - 1); // exclude last
+        std::optional<ChatModelAccessor> items;
+        std::span<const ChatItem> view;
+        if (chat) {
+            view = *chat;
+        } else {
+            items = m_chatModel->chatItems(); // holds lock
+            Q_ASSERT(!items->empty());
+            view = *items;
+        }
+        Q_ASSERT(view.size() >= 2); // should be prompt/response pairs
+
+        // Find the prompt that represents the query. Server chats are flexible and may not have one.
+        auto response = view.end() - 1;
+        if (auto peer = m_chatModel->getPeer(view, response))
+            query = { *peer - view.begin(), (*peer)->value };
+
+        chatItems.assign(view.begin(), view.end() - 1); // exclude last
     }
+
+    QList<ResultInfo> databaseResults;
+    if (query && !enabledCollections.isEmpty()) {
+        auto &[promptIndex, queryStr] = *query;
+        const int retrievalSize = MySettings::globalInstance()->localDocsRetrievalSize();
+        emit requestRetrieveFromDB(enabledCollections, queryStr, retrievalSize, &databaseResults); // blocks
+        m_chatModel->updateSources(promptIndex, databaseResults);
+        emit databaseResultsChanged(databaseResults);
+    }
+
     auto result = promptInternal(chatItems, ctx, !databaseResults.isEmpty());
     return {
         /*PromptResult*/ {
@@ -933,7 +938,7 @@ auto ChatLLM::promptInternal(
         }
     }
 
-    PromptResult result;
+    PromptResult result {};
     auto handlePrompt = [this, &result](std::span<const LLModel::Token> batch, bool cached) -> bool {
         Q_UNUSED(cached)
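The switch from "PromptResult result;" to "PromptResult result {};" is the subtle part of this hunk: value-initialization zeroes the aggregate's members, whereas the old default-initialized local left them indeterminate, which is presumably what produced the incorrect token counts and stop reason noted in the changelog. A self-contained sketch of the distinction, using a stand-in Counts struct rather than the real PromptResult (whose members are not shown in this diff):

#include <iostream>

// Stand-in aggregate; the member names are assumptions for illustration only.
struct Counts {
    int promptTokens;
    int responseTokens;
};

int main()
{
    Counts defaulted;  // default-initialized: members hold indeterminate values (reading them is UB)
    Counts zeroed {};  // value-initialized: members start at zero before anything is accumulated
    (void) defaulted;  // intentionally never read
    std::cout << zeroed.promptTokens << ' ' << zeroed.responseTokens << '\n';  // prints "0 0"
}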

View File

@@ -250,7 +250,8 @@ protected:
         QList<ResultInfo> databaseResults;
     };
 
-    ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx);
+    ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx,
+                                        std::optional<QList<ChatItem>> chat = {});
     // passing a string_view directly skips templating and uses the raw string
     PromptResult promptInternal(const std::variant<std::span<const ChatItem>, std::string_view> &prompt,
                                 const LLModel::PromptContext &ctx,

View File

@@ -111,8 +111,13 @@ public:
     ChatItem(prompt_tag_t, const QString &value, const QList<PromptAttachment> &attachments = {})
         : name(u"Prompt: "_s), value(value), promptAttachments(attachments) {}
 
-    ChatItem(response_tag_t, bool isCurrentResponse = true)
-        : name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {}
+    // A new response, to be filled in
+    ChatItem(response_tag_t)
+        : name(u"Response: "_s), isCurrentResponse(true) {}
+
+    // An existing response, from Server
+    ChatItem(response_tag_t, const QString &value)
+        : name(u"Response: "_s), value(value) {}
 
     Type type() const
     {
@@ -165,9 +170,9 @@ public:
 };
 Q_DECLARE_METATYPE(ChatItem)
 
-class ChatModelAccessor : public ranges::subrange<QList<ChatItem>::const_iterator> {
+class ChatModelAccessor : public std::span<const ChatItem> {
 private:
-    using Super = ranges::subrange<QList<ChatItem>::const_iterator>;
+    using Super = std::span<const ChatItem>;
 
 public:
     template <typename... T>
@@ -219,38 +224,38 @@ public:
     /* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs
      * sources. Return std::nullopt if there is none, which is possible for e.g. server chats. */
-    auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const
-        -> std::optional<QList<ChatItem>::const_iterator>
+    static std::optional<qsizetype> getPeer(const ChatItem *arr, qsizetype size, qsizetype index)
     {
-        switch (item->type()) {
+        Q_ASSERT(index >= 0);
+        Q_ASSERT(index < size);
+        qsizetype peer;
+        ChatItem::Type expected;
+        switch (arr[index].type()) {
             using enum ChatItem::Type;
-            case Prompt:
-                {
-                    auto peer = std::next(item);
-                    if (peer < m_chatItems.cend() && peer->type() == Response)
-                        return peer;
-                    break;
-                }
-            case Response:
-                {
-                    if (item > m_chatItems.cbegin()) {
-                        if (auto peer = std::prev(item); peer->type() == Prompt)
-                            return peer;
-                    }
-                    break;
-                }
-            default:
-                throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
+            case Prompt:   peer = index + 1; expected = Response; break;
+            case Response: peer = index - 1; expected = Prompt;   break;
+            default: throw std::invalid_argument("getPeer() called on item that is not a prompt or response");
         }
+        if (peer >= 0 && peer < size && arr[peer].type() == expected)
+            return peer;
         return std::nullopt;
     }
 
-    auto getPeerUnlocked(int index) const -> std::optional<int>
+    template <ranges::contiguous_range R>
+        requires std::same_as<ranges::range_value_t<R>, ChatItem>
+    static auto getPeer(R &&range, ranges::iterator_t<R> item) -> std::optional<ranges::iterator_t<R>>
     {
-        return getPeerUnlocked(m_chatItems.cbegin() + index)
-            .transform([&](auto &&i) { return i - m_chatItems.cbegin(); } );
+        auto begin = ranges::begin(range);
+        return getPeer(ranges::data(range), ranges::size(range), item - begin)
+            .transform([&](auto i) { return begin + i; });
     }
 
+    auto getPeerUnlocked(QList<ChatItem>::const_iterator item) const -> std::optional<QList<ChatItem>::const_iterator>
+    { return getPeer(m_chatItems, item); }
+
+    std::optional<qsizetype> getPeerUnlocked(qsizetype index) const
+    { return getPeer(m_chatItems.constData(), m_chatItems.size(), index); }
+
     QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
     {
         QMutexLocker locker(&m_mutex);
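To make the peer-lookup rule concrete, here is a minimal, self-contained sketch of the same index arithmetic as the new static getPeer(), using a stand-in Type enum instead of ChatItem; the sample histories are illustrative, not taken from the codebase:

#include <cstddef>
#include <iostream>
#include <optional>
#include <stdexcept>

// Stand-in for ChatItem::Type.
enum class Type { System, Prompt, Response };

// Mirrors the index arithmetic of the static getPeer() overload above.
static std::optional<std::ptrdiff_t> getPeer(const Type *arr, std::ptrdiff_t size, std::ptrdiff_t index)
{
    std::ptrdiff_t peer;
    Type expected;
    switch (arr[index]) {
        case Type::Prompt:   peer = index + 1; expected = Type::Response; break;
        case Type::Response: peer = index - 1; expected = Type::Prompt;   break;
        default: throw std::invalid_argument("not a prompt or response");
    }
    if (peer >= 0 && peer < size && arr[peer] == expected)
        return peer;
    return std::nullopt;
}

int main()
{
    // A server-style history: system, user, assistant, user, then the blank response being generated.
    Type chat[] = { Type::System, Type::Prompt, Type::Response, Type::Prompt, Type::Response };
    std::cout << *getPeer(chat, 5, 4) << '\n';  // 3: the final response's peer is the prompt before it
    std::cout << *getPeer(chat, 5, 1) << '\n';  // 2: a prompt's peer is the response right after it

    // A response with no prompt directly before it has no peer (possible in flexible server chats).
    Type odd[] = { Type::System, Type::Response };
    std::cout << getPeer(odd, 2, 1).has_value() << '\n';  // 0 (std::nullopt)
}

Because both getPeerUnlocked() overloads now forward to these static helpers, the same rule can also be applied to a copied view of a server conversation rather than only to m_chatItems.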
@@ -356,26 +361,29 @@ public:
     }
 
     // Used by Server to append a new conversation to the chat log.
-    void appendResponseWithHistory(std::span<const ChatItem> history)
+    // Appends a new, blank response to the end of the input list.
+    void appendResponseWithHistory(QList<ChatItem> &history)
     {
         if (history.empty())
            throw std::invalid_argument("at least one message is required");
 
+        // add an empty response to prepare for generation
+        history.emplace_back(ChatItem::response_tag);
+
         m_mutex.lock();
-        qsizetype startIndex = m_chatItems.count();
+        qsizetype startIndex = m_chatItems.size();
         m_mutex.unlock();
-        qsizetype nNewItems = history.size() + 1;
-        qsizetype endIndex = startIndex + nNewItems;
+        qsizetype endIndex = startIndex + history.size();
         beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/);
         bool hadError;
-        QList<ChatItem> newItems;
         {
             QMutexLocker locker(&m_mutex);
             hadError = hasErrorUnlocked();
-            m_chatItems.reserve(m_chatItems.count() + nNewItems);
+            m_chatItems.reserve(m_chatItems.size() + history.size());
             for (auto &item : history)
                 m_chatItems << item;
-            m_chatItems.emplace_back(ChatItem::response_tag);
         }
         endInsertRows();
         emit countChanged();
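The behavioral change worth noting here is that the blank response is now appended to the caller's own list before everything is copied into the chat log, so Server can hand that same list straight to promptInternalChat(). A simplified, self-contained sketch of that contract with stand-in types (no locking or model signals, and Item is not the real ChatItem):

#include <cassert>
#include <string>
#include <vector>

// Stand-ins for ChatItem and the model's item storage, for illustration only.
struct Item { std::string name; std::string value; };

// Mirrors the new contract: append a blank response to the caller's history,
// then copy the whole list (history plus placeholder) into the chat log.
static void appendResponseWithHistory(std::vector<Item> &history, std::vector<Item> &chatLog)
{
    history.push_back({"Response: ", ""});  // blank response, to be filled in by generation
    chatLog.insert(chatLog.end(), history.begin(), history.end());
}

int main()
{
    std::vector<Item> chatLog;
    std::vector<Item> history {
        {"System: ",   "You are a helpful assistant."},
        {"Prompt: ",   "Hello"},
        {"Response: ", "Hi there!"},  // existing assistant reply taken from the request
        {"Prompt: ",   "What is 2+2?"},
    };
    appendResponseWithHistory(history, chatLog);
    assert(history.back().value.empty());  // the caller's list now ends with the placeholder
    assert(chatLog.size() == 5);           // the log received the history plus the placeholder
}

This is also why the Assistant case in server.cpp below can construct responses from message.content: history supplied by the API client is replayed as-is instead of being dropped.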

View File

@@ -771,13 +771,13 @@ auto Server::handleChatRequest(const ChatRequest &request)
     Q_ASSERT(!request.messages.isEmpty());
 
     // adds prompt/response items to GUI
-    std::vector<ChatItem> chatItems;
+    QList<ChatItem> chatItems;
     for (auto &message : request.messages) {
         using enum ChatRequest::Message::Role;
         switch (message.role) {
             case System:    chatItems.emplace_back(ChatItem::system_tag, message.content); break;
             case User:      chatItems.emplace_back(ChatItem::prompt_tag, message.content); break;
-            case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break;
+            case Assistant: chatItems.emplace_back(ChatItem::response_tag, message.content); break;
         }
     }
     m_chatModel->appendResponseWithHistory(chatItems);
@@ -800,7 +800,7 @@ auto Server::handleChatRequest(const ChatRequest &request)
     for (int i = 0; i < request.n; ++i) {
         ChatPromptResult result;
         try {
-            result = promptInternalChat(m_collections, promptCtx);
+            result = promptInternalChat(m_collections, promptCtx, chatItems);
         } catch (const std::exception &e) {
             emit responseChanged(e.what());
             emit responseStopped(0);

View File

@@ -252,8 +252,8 @@ def test_with_models(chat_server_with_model: None) -> None:
     assert response == EXPECTED_COMPLETIONS_RESPONSE
 
 
-@pytest.mark.xfail(reason='Assertion failure in GPT4All. See nomic-ai/gpt4all#3133')
 def test_with_models_temperature(chat_server_with_model: None) -> None:
+    """Fixed by nomic-ai/gpt4all#3202."""
     data = {
         'model': 'Llama 3.2 1B Instruct',
         'prompt': 'The quick brown fox',