Make ChatModel threadsafe to support direct access by ChatLLM (#3018)

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
AT 2024-10-01 18:15:02 -04:00 committed by GitHub
parent ee67cca885
commit c11b67dfcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 140 additions and 73 deletions

View File

@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
- Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004))
- Simplify chatmodel to get rid of unnecessary field and bump chat version ([#3016](https://github.com/nomic-ai/gpt4all/pull/3016))
- Allow ChatLLM to have direct access to ChatModel for restoring state from text ([#3018](https://github.com/nomic-ai/gpt4all/pull/3018))
### Fixed
- Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))

View File

@ -6,15 +6,13 @@
#include "server.h"
#include <QDataStream>
#include <QDateTime>
#include <QDebug>
#include <QLatin1String>
#include <QMap>
#include <QString>
#include <QStringList>
#include <QTextStream>
#include <QVariant>
#include <Qt>
#include <QtGlobal>
#include <QtLogging>
#include <utility>
@ -443,8 +441,6 @@ bool Chat::deserialize(QDataStream &stream, int version)
if (!m_chatModel->deserialize(stream, version))
return false;
m_llmodel->setStateFromText(m_chatModel->text());
emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}

View File

@ -7,6 +7,7 @@
#include "localdocsmodel.h" // IWYU pragma: keep
#include "modellist.h"
#include <QDateTime>
#include <QList>
#include <QObject>
#include <QQmlEngine>

View File

@ -2,6 +2,7 @@
#include "chat.h"
#include "chatapi.h"
#include "chatmodel.h"
#include "localdocs.h"
#include "mysettings.h"
#include "network.h"
@ -13,10 +14,14 @@
#include <QIODevice>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QMap>
#include <QMutex>
#include <QMutexLocker>
#include <QRegularExpression>
#include <QSet>
#include <QStringList>
#include <QUrl>
#include <QWaitCondition>
#include <Qt>
#include <QtLogging>
@ -113,6 +118,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
, m_chatModel(parent->chatModel())
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
@ -1313,31 +1319,32 @@ void ChatLLM::processRestoreStateFromText()
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
auto it = m_stateFromText.begin();
while (it < m_stateFromText.end()) {
Q_ASSERT(m_chatModel);
m_chatModel->lock();
auto it = m_chatModel->begin();
while (it < m_chatModel->end()) {
auto &prompt = *it++;
Q_ASSERT(prompt.first == "Prompt: ");
Q_ASSERT(it < m_stateFromText.end());
Q_ASSERT(prompt.name == "Prompt: ");
Q_ASSERT(it < m_chatModel->end());
auto &response = *it++;
Q_ASSERT(response.first != "Prompt: ");
Q_ASSERT(response.name == "Response: ");
// FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing
// m_promptTokens or m_promptResponseTokens
m_llModelInfo.model->prompt(
prompt.second.toStdString(), promptTemplate.toStdString(),
prompt.value.toStdString(), promptTemplate.toStdString(),
promptFunc, /*responseFunc*/ [](auto &&...) { return true; },
/*allowContextShift*/ true,
m_ctx,
/*special*/ false,
response.second.toUtf8().constData()
response.value.toUtf8().constData()
);
}
m_chatModel->unlock();
if (!m_stopGenerating) {
if (!m_stopGenerating)
m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_restoringFromText = false;
emit restoringFromTextChanged();

View File

@ -11,11 +11,10 @@
#include <QFileInfo>
#include <QList>
#include <QObject>
#include <QPair>
#include <QPointer>
#include <QString>
#include <QThread>
#include <QVariantMap>
#include <QVector>
#include <QtGlobal>
#include <atomic>
@ -37,6 +36,7 @@ enum LLModelType {
};
class ChatLLM;
class ChatModel;
struct LLModelInfo {
std::unique_ptr<LLModel> model;
@ -151,7 +151,6 @@ public:
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
@ -244,7 +243,7 @@ private:
// - an unload was queued during LLModel::restoreState()
// - the chat will be restored from text and hasn't been interacted with yet
bool m_pristineLoadedState = false;
QVector<QPair<QString, QString>> m_stateFromText;
QPointer<ChatModel> m_chatModel;
};
#endif // CHATLLM_H

View File

@ -45,6 +45,8 @@ public:
};
Q_DECLARE_METATYPE(ChatItem)
using ChatModelIterator = QList<ChatItem>::const_iterator;
class ChatModel : public QAbstractListModel
{
Q_OBJECT
@ -68,12 +70,14 @@ public:
int rowCount(const QModelIndex &parent = QModelIndex()) const override
{
QMutexLocker locker(&m_mutex);
Q_UNUSED(parent)
return m_chatItems.size();
}
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
{
QMutexLocker locker(&m_mutex);
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
return QVariant();
@ -125,75 +129,112 @@ public:
ChatItem item;
item.name = name;
item.value = value;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
}
endInsertRows();
emit countChanged();
}
void appendResponse(const QString &name)
{
m_mutex.lock();
const int count = m_chatItems.count();
m_mutex.unlock();
ChatItem item;
item.id = m_chatItems.count(); // This is only relevant for responses
item.id = count; // This is only relevant for responses
item.name = name;
item.currentResponse = true;
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(item);
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(item);
}
endInsertRows();
emit countChanged();
}
Q_INVOKABLE void clear()
{
if (m_chatItems.isEmpty()) return;
{
QMutexLocker locker(&m_mutex);
if (m_chatItems.isEmpty()) return;
}
beginResetModel();
m_chatItems.clear();
{
QMutexLocker locker(&m_mutex);
m_chatItems.clear();
}
endResetModel();
emit countChanged();
}
Q_INVOKABLE ChatItem get(int index)
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
return m_chatItems.at(index);
}
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) {
item.currentResponse = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
ChatItem &item = m_chatItems[index];
if (item.currentResponse != b) {
item.currentResponse = b;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
}
Q_INVOKABLE void updateStopped(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.stopped != b) {
item.stopped = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
ChatItem &item = m_chatItems[index];
if (item.stopped != b) {
item.stopped = b;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
}
Q_INVOKABLE void updateValue(int index, const QString &value)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.value != value) {
item.value = value;
ChatItem &item = m_chatItems[index];
if (item.value != value) {
item.value = value;
changed = true;
}
}
if (changed) {
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
emit valueChanged(index, value);
}
}
QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) {
static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) {
QMap<QString, ResultInfo> groupedData;
for (const ResultInfo &info : sources) {
if (groupedData.contains(info.file)) {
@ -208,53 +249,77 @@ public:
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
{
if (index < 0 || index >= m_chatItems.size()) return;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
item.sources = sources;
item.consolidatedSources = consolidateSources(sources);
ChatItem &item = m_chatItems[index];
item.sources = sources;
item.consolidatedSources = consolidateSources(sources);
}
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
}
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsUpState != b) {
item.thumbsUpState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
ChatItem &item = m_chatItems[index];
if (item.thumbsUpState != b) {
item.thumbsUpState = b;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
}
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.thumbsDownState != b) {
item.thumbsDownState = b;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
ChatItem &item = m_chatItems[index];
if (item.thumbsDownState != b) {
item.thumbsDownState = b;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
}
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
{
if (index < 0 || index >= m_chatItems.size()) return;
bool changed = false;
{
QMutexLocker locker(&m_mutex);
if (index < 0 || index >= m_chatItems.size()) return;
ChatItem &item = m_chatItems[index];
if (item.newResponse != newResponse) {
item.newResponse = newResponse;
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
ChatItem &item = m_chatItems[index];
if (item.newResponse != newResponse) {
item.newResponse = newResponse;
changed = true;
}
}
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
}
int count() const { return m_chatItems.size(); }
int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
ChatModelIterator begin() const { return m_chatItems.begin(); }
ChatModelIterator end() const { return m_chatItems.end(); }
void lock() { m_mutex.lock(); }
void unlock() { m_mutex.unlock(); }
bool serialize(QDataStream &stream, int version) const
{
stream << count();
QMutexLocker locker(&m_mutex);
stream << int(m_chatItems.size());
for (const auto &c : m_chatItems) {
stream << c.id;
stream << c.name;
@ -442,28 +507,26 @@ public:
c.consolidatedSources = consolidateSources(sources);
}
}
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
m_chatItems.append(c);
m_mutex.lock();
const int count = m_chatItems.size();
m_mutex.unlock();
beginInsertRows(QModelIndex(), count, count);
{
QMutexLocker locker(&m_mutex);
m_chatItems.append(c);
}
endInsertRows();
}
emit countChanged();
return stream.status() == QDataStream::Ok;
}
QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}
Q_SIGNALS:
void countChanged();
void valueChanged(int index, const QString &value);
private:
mutable QMutex m_mutex;
QList<ChatItem> m_chatItems;
};