From c11b67dfcbfbef027f64df70713466947f60c1b5 Mon Sep 17 00:00:00 2001 From: AT Date: Tue, 1 Oct 2024 18:15:02 -0400 Subject: [PATCH] Make ChatModel threadsafe to support direct access by ChatLLM (#3018) Signed-off-by: Adam Treat Signed-off-by: Jared Van Bortel Co-authored-by: Jared Van Bortel --- gpt4all-chat/CHANGELOG.md | 1 + gpt4all-chat/src/chat.cpp | 6 +- gpt4all-chat/src/chat.h | 1 + gpt4all-chat/src/chatllm.cpp | 27 ++++-- gpt4all-chat/src/chatllm.h | 7 +- gpt4all-chat/src/chatmodel.h | 171 ++++++++++++++++++++++++----------- 6 files changed, 140 insertions(+), 73 deletions(-) diff --git a/gpt4all-chat/CHANGELOG.md b/gpt4all-chat/CHANGELOG.md index ddd909b9..8f341fb9 100644 --- a/gpt4all-chat/CHANGELOG.md +++ b/gpt4all-chat/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). - Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998)) - Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004)) - Simplify chatmodel to get rid of unnecessary field and bump chat version ([#3016](https://github.com/nomic-ai/gpt4all/pull/3016)) +- Allow ChatLLM to have direct access to ChatModel for restoring state from text ([#3018](https://github.com/nomic-ai/gpt4all/pull/3018)) ### Fixed - Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995)) diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index 6ee7b30f..fb5e7763 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -6,15 +6,13 @@ #include "server.h" #include -#include #include #include #include #include #include -#include +#include #include -#include #include #include @@ -443,8 +441,6 @@ bool Chat::deserialize(QDataStream &stream, int version) if (!m_chatModel->deserialize(stream, version)) return false; - m_llmodel->setStateFromText(m_chatModel->text()); - emit chatModelChanged(); return stream.status() == QDataStream::Ok; } diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h index e3cdc756..21f794de 100644 --- a/gpt4all-chat/src/chat.h +++ b/gpt4all-chat/src/chat.h @@ -7,6 +7,7 @@ #include "localdocsmodel.h" // IWYU pragma: keep #include "modellist.h" +#include #include #include #include diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index abbdf322..fece36fb 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -2,6 +2,7 @@ #include "chat.h" #include "chatapi.h" +#include "chatmodel.h" #include "localdocs.h" #include "mysettings.h" #include "network.h" @@ -13,10 +14,14 @@ #include #include #include +#include +#include #include #include +#include #include #include +#include #include #include #include @@ -113,6 +118,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_reloadingToChangeVariant(false) , m_processedSystemPrompt(false) , m_restoreStateFromText(false) + , m_chatModel(parent->chatModel()) { moveToThread(&m_llmThread); connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, @@ -1313,31 +1319,32 @@ void ChatLLM::processRestoreStateFromText() m_ctx.repeat_last_n = repeat_penalty_tokens; m_llModelInfo.model->setThreadCount(n_threads); - auto it = m_stateFromText.begin(); - while (it < m_stateFromText.end()) { + Q_ASSERT(m_chatModel); + m_chatModel->lock(); + auto it = m_chatModel->begin(); + while (it < m_chatModel->end()) { auto &prompt = *it++; - Q_ASSERT(prompt.first == "Prompt: "); - Q_ASSERT(it < m_stateFromText.end()); + Q_ASSERT(prompt.name == "Prompt: "); + Q_ASSERT(it < m_chatModel->end()); auto &response = *it++; - Q_ASSERT(response.first != "Prompt: "); + Q_ASSERT(response.name == "Response: "); // FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing // m_promptTokens or m_promptResponseTokens m_llModelInfo.model->prompt( - prompt.second.toStdString(), promptTemplate.toStdString(), + prompt.value.toStdString(), promptTemplate.toStdString(), promptFunc, /*responseFunc*/ [](auto &&...) { return true; }, /*allowContextShift*/ true, m_ctx, /*special*/ false, - response.second.toUtf8().constData() + response.value.toUtf8().constData() ); } + m_chatModel->unlock(); - if (!m_stopGenerating) { + if (!m_stopGenerating) m_restoreStateFromText = false; - m_stateFromText.clear(); - } m_restoringFromText = false; emit restoringFromTextChanged(); diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index 62d83753..c486b06c 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -11,11 +11,10 @@ #include #include #include -#include +#include #include #include #include -#include #include #include @@ -37,6 +36,7 @@ enum LLModelType { }; class ChatLLM; +class ChatModel; struct LLModelInfo { std::unique_ptr model; @@ -151,7 +151,6 @@ public: bool serialize(QDataStream &stream, int version, bool serializeKV); bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); - void setStateFromText(const QVector> &stateFromText) { m_stateFromText = stateFromText; } public Q_SLOTS: bool prompt(const QList &collectionList, const QString &prompt); @@ -244,7 +243,7 @@ private: // - an unload was queued during LLModel::restoreState() // - the chat will be restored from text and hasn't been interacted with yet bool m_pristineLoadedState = false; - QVector> m_stateFromText; + QPointer m_chatModel; }; #endif // CHATLLM_H diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 2b441a79..d971ad26 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -45,6 +45,8 @@ public: }; Q_DECLARE_METATYPE(ChatItem) +using ChatModelIterator = QList::const_iterator; + class ChatModel : public QAbstractListModel { Q_OBJECT @@ -68,12 +70,14 @@ public: int rowCount(const QModelIndex &parent = QModelIndex()) const override { + QMutexLocker locker(&m_mutex); Q_UNUSED(parent) return m_chatItems.size(); } QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override { + QMutexLocker locker(&m_mutex); if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) return QVariant(); @@ -125,75 +129,112 @@ public: ChatItem item; item.name = name; item.value = value; - beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); - m_chatItems.append(item); + m_mutex.lock(); + const int count = m_chatItems.count(); + m_mutex.unlock(); + beginInsertRows(QModelIndex(), count, count); + { + QMutexLocker locker(&m_mutex); + m_chatItems.append(item); + } endInsertRows(); emit countChanged(); } void appendResponse(const QString &name) { + m_mutex.lock(); + const int count = m_chatItems.count(); + m_mutex.unlock(); ChatItem item; - item.id = m_chatItems.count(); // This is only relevant for responses + item.id = count; // This is only relevant for responses item.name = name; item.currentResponse = true; - beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); - m_chatItems.append(item); + beginInsertRows(QModelIndex(), count, count); + { + QMutexLocker locker(&m_mutex); + m_chatItems.append(item); + } endInsertRows(); emit countChanged(); } Q_INVOKABLE void clear() { - if (m_chatItems.isEmpty()) return; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty()) return; + } beginResetModel(); - m_chatItems.clear(); + { + QMutexLocker locker(&m_mutex); + m_chatItems.clear(); + } endResetModel(); emit countChanged(); } Q_INVOKABLE ChatItem get(int index) { + QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return ChatItem(); return m_chatItems.at(index); } Q_INVOKABLE void updateCurrentResponse(int index, bool b) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.currentResponse != b) { - item.currentResponse = b; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); + ChatItem &item = m_chatItems[index]; + if (item.currentResponse != b) { + item.currentResponse = b; + changed = true; + } } + + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); } Q_INVOKABLE void updateStopped(int index, bool b) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.stopped != b) { - item.stopped = b; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); + ChatItem &item = m_chatItems[index]; + if (item.stopped != b) { + item.stopped = b; + changed = true; + } } + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); } Q_INVOKABLE void updateValue(int index, const QString &value) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.value != value) { - item.value = value; + ChatItem &item = m_chatItems[index]; + if (item.value != value) { + item.value = value; + changed = true; + } + } + if (changed) { emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); emit valueChanged(index, value); } } - QList consolidateSources(const QList &sources) { + static QList consolidateSources(const QList &sources) { QMap groupedData; for (const ResultInfo &info : sources) { if (groupedData.contains(info.file)) { @@ -208,53 +249,77 @@ public: Q_INVOKABLE void updateSources(int index, const QList &sources) { - if (index < 0 || index >= m_chatItems.size()) return; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - item.sources = sources; - item.consolidatedSources = consolidateSources(sources); + ChatItem &item = m_chatItems[index]; + item.sources = sources; + item.consolidatedSources = consolidateSources(sources); + } emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); } Q_INVOKABLE void updateThumbsUpState(int index, bool b) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsUpState != b) { - item.thumbsUpState = b; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole}); + ChatItem &item = m_chatItems[index]; + if (item.thumbsUpState != b) { + item.thumbsUpState = b; + changed = true; + } } + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole}); } Q_INVOKABLE void updateThumbsDownState(int index, bool b) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsDownState != b) { - item.thumbsDownState = b; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole}); + ChatItem &item = m_chatItems[index]; + if (item.thumbsDownState != b) { + item.thumbsDownState = b; + changed = true; + } } + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole}); } Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse) { - if (index < 0 || index >= m_chatItems.size()) return; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.newResponse != newResponse) { - item.newResponse = newResponse; - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); + ChatItem &item = m_chatItems[index]; + if (item.newResponse != newResponse) { + item.newResponse = newResponse; + changed = true; + } } + if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); } - int count() const { return m_chatItems.size(); } + int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } + + ChatModelIterator begin() const { return m_chatItems.begin(); } + ChatModelIterator end() const { return m_chatItems.end(); } + void lock() { m_mutex.lock(); } + void unlock() { m_mutex.unlock(); } bool serialize(QDataStream &stream, int version) const { - stream << count(); + QMutexLocker locker(&m_mutex); + stream << int(m_chatItems.size()); for (const auto &c : m_chatItems) { stream << c.id; stream << c.name; @@ -442,28 +507,26 @@ public: c.consolidatedSources = consolidateSources(sources); } } - beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); - m_chatItems.append(c); + m_mutex.lock(); + const int count = m_chatItems.size(); + m_mutex.unlock(); + beginInsertRows(QModelIndex(), count, count); + { + QMutexLocker locker(&m_mutex); + m_chatItems.append(c); + } endInsertRows(); } emit countChanged(); return stream.status() == QDataStream::Ok; } - QVector> text() const - { - QVector> result; - for (const auto &c : m_chatItems) - result << qMakePair(c.name, c.value); - return result; - } - Q_SIGNALS: void countChanged(); void valueChanged(int index, const QString &value); private: - + mutable QMutex m_mutex; QList m_chatItems; };