Make ChatModel threadsafe to support direct access by ChatLLM (#3018)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT 2024-10-01 18:15:02 -04:00 committed by GitHub
parent ee67cca885
commit c11b67dfcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 140 additions and 73 deletions

View File

@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998)) - Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
- Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004)) - Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004))
- Simplify chatmodel to get rid of unnecessary field and bump chat version ([#3016](https://github.com/nomic-ai/gpt4all/pull/3016)) - Simplify chatmodel to get rid of unnecessary field and bump chat version ([#3016](https://github.com/nomic-ai/gpt4all/pull/3016))
- Allow ChatLLM to have direct access to ChatModel for restoring state from text ([#3018](https://github.com/nomic-ai/gpt4all/pull/3018))
### Fixed ### Fixed
- Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995)) - Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))

View File

@ -6,15 +6,13 @@
#include "server.h" #include "server.h"
#include <QDataStream> #include <QDataStream>
#include <QDateTime>
#include <QDebug> #include <QDebug>
#include <QLatin1String> #include <QLatin1String>
#include <QMap> #include <QMap>
#include <QString> #include <QString>
#include <QStringList> #include <QStringList>
#include <QTextStream> #include <QVariant>
#include <Qt> #include <Qt>
#include <QtGlobal>
#include <QtLogging> #include <QtLogging>
#include <utility> #include <utility>
@ -443,8 +441,6 @@ bool Chat::deserialize(QDataStream &stream, int version)
if (!m_chatModel->deserialize(stream, version)) if (!m_chatModel->deserialize(stream, version))
return false; return false;
m_llmodel->setStateFromText(m_chatModel->text());
emit chatModelChanged(); emit chatModelChanged();
return stream.status() == QDataStream::Ok; return stream.status() == QDataStream::Ok;
} }

View File

@ -7,6 +7,7 @@
#include "localdocsmodel.h" // IWYU pragma: keep #include "localdocsmodel.h" // IWYU pragma: keep
#include "modellist.h" #include "modellist.h"
#include <QDateTime>
#include <QList> #include <QList>
#include <QObject> #include <QObject>
#include <QQmlEngine> #include <QQmlEngine>

View File

@ -2,6 +2,7 @@
#include "chat.h" #include "chat.h"
#include "chatapi.h" #include "chatapi.h"
#include "chatmodel.h"
#include "localdocs.h" #include "localdocs.h"
#include "mysettings.h" #include "mysettings.h"
#include "network.h" #include "network.h"
@ -13,10 +14,14 @@
#include <QIODevice> #include <QIODevice>
#include <QJsonDocument> #include <QJsonDocument>
#include <QJsonObject> #include <QJsonObject>
#include <QJsonValue>
#include <QMap>
#include <QMutex> #include <QMutex>
#include <QMutexLocker> #include <QMutexLocker>
#include <QRegularExpression>
#include <QSet> #include <QSet>
#include <QStringList> #include <QStringList>
#include <QUrl>
#include <QWaitCondition> #include <QWaitCondition>
#include <Qt> #include <Qt>
#include <QtLogging> #include <QtLogging>
@ -113,6 +118,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_reloadingToChangeVariant(false) , m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false) , m_processedSystemPrompt(false)
, m_restoreStateFromText(false) , m_restoreStateFromText(false)
, m_chatModel(parent->chatModel())
{ {
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
@ -1313,31 +1319,32 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_last_n = repeat_penalty_tokens; m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads); m_llModelInfo.model->setThreadCount(n_threads);
auto it = m_stateFromText.begin(); Q_ASSERT(m_chatModel);
while (it < m_stateFromText.end()) { m_chatModel->lock();
auto it = m_chatModel->begin();
while (it < m_chatModel->end()) {
auto &prompt = *it++; auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: "); Q_ASSERT(prompt.name == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end()); Q_ASSERT(it < m_chatModel->end());
auto &response = *it++; auto &response = *it++;
Q_ASSERT(response.first != "Prompt: "); Q_ASSERT(response.name == "Response: ");
// FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing // FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing
// m_promptTokens or m_promptResponseTokens // m_promptTokens or m_promptResponseTokens
m_llModelInfo.model->prompt( m_llModelInfo.model->prompt(
prompt.second.toStdString(), promptTemplate.toStdString(), prompt.value.toStdString(), promptTemplate.toStdString(),
promptFunc, /*responseFunc*/ [](auto &&...) { return true; }, promptFunc, /*responseFunc*/ [](auto &&...) { return true; },
/*allowContextShift*/ true, /*allowContextShift*/ true,
m_ctx, m_ctx,
/*special*/ false, /*special*/ false,
response.second.toUtf8().constData() response.value.toUtf8().constData()
); );
} }
m_chatModel->unlock();
if (!m_stopGenerating) { if (!m_stopGenerating)
m_restoreStateFromText = false; m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_restoringFromText = false; m_restoringFromText = false;
emit restoringFromTextChanged(); emit restoringFromTextChanged();

View File

@ -11,11 +11,10 @@
#include <QFileInfo> #include <QFileInfo>
#include <QList> #include <QList>
#include <QObject> #include <QObject>
#include <QPair> #include <QPointer>
#include <QString> #include <QString>
#include <QThread> #include <QThread>
#include <QVariantMap> #include <QVariantMap>
#include <QVector>
#include <QtGlobal> #include <QtGlobal>
#include <atomic> #include <atomic>
@ -37,6 +36,7 @@ enum LLModelType {
}; };
class ChatLLM; class ChatLLM;
class ChatModel;
struct LLModelInfo { struct LLModelInfo {
std::unique_ptr<LLModel> model; std::unique_ptr<LLModel> model;
@ -151,7 +151,6 @@ public:
bool serialize(QDataStream &stream, int version, bool serializeKV); bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
public Q_SLOTS: public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt); bool prompt(const QList<QString> &collectionList, const QString &prompt);
@ -244,7 +243,7 @@ private:
// - an unload was queued during LLModel::restoreState() // - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet // - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false; bool m_pristineLoadedState = false;
QVector<QPair<QString, QString>> m_stateFromText; QPointer<ChatModel> m_chatModel;
}; };
#endif // CHATLLM_H #endif // CHATLLM_H

View File

@ -45,6 +45,8 @@ public:
}; };
Q_DECLARE_METATYPE(ChatItem) Q_DECLARE_METATYPE(ChatItem)
using ChatModelIterator = QList<ChatItem>::const_iterator;
class ChatModel : public QAbstractListModel class ChatModel : public QAbstractListModel
{ {
Q_OBJECT Q_OBJECT
@ -68,12 +70,14 @@ public:
int rowCount(const QModelIndex &parent = QModelIndex()) const override int rowCount(const QModelIndex &parent = QModelIndex()) const override
{ {
QMutexLocker locker(&m_mutex);
Q_UNUSED(parent) Q_UNUSED(parent)
return m_chatItems.size(); return m_chatItems.size();
} }
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{ {
QMutexLocker locker(&m_mutex);
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant(); return QVariant();
@ -125,75 +129,112 @@ public:
ChatItem item; ChatItem item;
item.name = name; item.name = name;
item.value = value; item.value = value;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); m_mutex.lock();
m_chatItems.append(item); const int count = m_chatItems.count();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
}
endInsertRows(); endInsertRows();
emit countChanged(); emit countChanged();
} }
void appendResponse(const QString &name) void appendResponse(const QString &name)
{ {
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
ChatItem item; ChatItem item;
item.id = m_chatItems.count(); // This is only relevant for responses item.id = count; // This is only relevant for responses
item.name = name; item.name = name;
item.currentResponse = true; item.currentResponse = true;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); beginInsertRows(QModelIndex(), count, count);
m_chatItems.append(item); {
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
}
endInsertRows(); endInsertRows();
emit countChanged(); emit countChanged();
} }
Q_INVOKABLE void clear() Q_INVOKABLE void clear()
{ {
if (m_chatItems.isEmpty()) return; {
QMutexLocker locker(&m_mutex);
if (m_chatItems.isEmpty()) return;
}
beginResetModel(); beginResetModel();
m_chatItems.clear(); {
QMutexLocker locker(&m_mutex);
m_chatItems.clear();
}
endResetModel(); endResetModel();
emit countChanged(); emit countChanged();
} }
Q_INVOKABLE ChatItem get(int index) Q_INVOKABLE ChatItem get(int index)
{ {
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return ChatItem(); if (index < 0 || index >= m_chatItems.size()) return ChatItem();
return m_chatItems.at(index); return m_chatItems.at(index);
} }
Q_INVOKABLE void updateCurrentResponse(int index, bool b) Q_INVOKABLE void updateCurrentResponse(int index, bool b)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) { if (item.currentResponse != b) {
item.currentResponse = b; item.currentResponse = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole}); changed = true;
}
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
} }
Q_INVOKABLE void updateStopped(int index, bool b) Q_INVOKABLE void updateStopped(int index, bool b)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.stopped != b) { if (item.stopped != b) {
item.stopped = b; item.stopped = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); changed = true;
}
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
} }
Q_INVOKABLE void updateValue(int index, const QString &value) Q_INVOKABLE void updateValue(int index, const QString &value)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.value != value) { if (item.value != value) {
item.value = value; item.value = value;
changed = true;
}
}
if (changed) {
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
emit valueChanged(index, value); emit valueChanged(index, value);
} }
} }
QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) { static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) {
QMap<QString, ResultInfo> groupedData; QMap<QString, ResultInfo> groupedData;
for (const ResultInfo &info : sources) { for (const ResultInfo &info : sources) {
if (groupedData.contains(info.file)) { if (groupedData.contains(info.file)) {
@ -208,53 +249,77 @@ public:
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources) Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
{ {
if (index < 0 || index >= m_chatItems.size()) return; {
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
item.sources = sources; item.sources = sources;
item.consolidatedSources = consolidateSources(sources); item.consolidatedSources = consolidateSources(sources);
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole}); emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
} }
Q_INVOKABLE void updateThumbsUpState(int index, bool b) Q_INVOKABLE void updateThumbsUpState(int index, bool b)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.thumbsUpState != b) { if (item.thumbsUpState != b) {
item.thumbsUpState = b; item.thumbsUpState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole}); changed = true;
}
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
} }
Q_INVOKABLE void updateThumbsDownState(int index, bool b) Q_INVOKABLE void updateThumbsDownState(int index, bool b)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.thumbsDownState != b) { if (item.thumbsDownState != b) {
item.thumbsDownState = b; item.thumbsDownState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole}); changed = true;
}
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
} }
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse) Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
{ {
if (index < 0 || index >= m_chatItems.size()) return; bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index]; ChatItem &item = m_chatItems[index];
if (item.newResponse != newResponse) { if (item.newResponse != newResponse) {
item.newResponse = newResponse; item.newResponse = newResponse;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); changed = true;
}
} }
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
} }
int count() const { return m_chatItems.size(); } int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
ChatModelIterator begin() const { return m_chatItems.begin(); }
ChatModelIterator end() const { return m_chatItems.end(); }
void lock() { m_mutex.lock(); }
void unlock() { m_mutex.unlock(); }
bool serialize(QDataStream &stream, int version) const bool serialize(QDataStream &stream, int version) const
{ {
stream << count(); QMutexLocker locker(&m_mutex);
stream << int(m_chatItems.size());
for (const auto &c : m_chatItems) { for (const auto &c : m_chatItems) {
stream << c.id; stream << c.id;
stream << c.name; stream << c.name;
@ -442,28 +507,26 @@ public:
c.consolidatedSources = consolidateSources(sources); c.consolidatedSources = consolidateSources(sources);
} }
} }
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size()); m_mutex.lock();
m_chatItems.append(c); const int count = m_chatItems.size();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(c);
}
endInsertRows(); endInsertRows();
} }
emit countChanged(); emit countChanged();
return stream.status() == QDataStream::Ok; return stream.status() == QDataStream::Ok;
} }
QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}
Q_SIGNALS: Q_SIGNALS:
void countChanged(); void countChanged();
void valueChanged(int index, const QString &value); void valueChanged(int index, const QString &value);
private: private:
mutable QMutex m_mutex;
QList<ChatItem> m_chatItems; QList<ChatItem> m_chatItems;
}; };