mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-21 21:19:08 +00:00
Make ChatModel threadsafe to support direct access by ChatLLM (#3018)
Signed-off-by: Adam Treat <treat.adam@gmail.com> Signed-off-by: Jared Van Bortel <jared@nomic.ai> Co-authored-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
ee67cca885
commit
c11b67dfcb
@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
||||
- Rebase llama.cpp on latest upstream as of September 26th ([#2998](https://github.com/nomic-ai/gpt4all/pull/2998))
|
||||
- Change the error message when a message is too long ([#3004](https://github.com/nomic-ai/gpt4all/pull/3004))
|
||||
- Simplify chatmodel to get rid of unnecessary field and bump chat version ([#3016](https://github.com/nomic-ai/gpt4all/pull/3016))
|
||||
- Allow ChatLLM to have direct access to ChatModel for restoring state from text ([#3018](https://github.com/nomic-ai/gpt4all/pull/3018))
|
||||
|
||||
### Fixed
|
||||
- Fix a crash when attempting to continue a chat loaded from disk ([#2995](https://github.com/nomic-ai/gpt4all/pull/2995))
|
||||
|
@ -6,15 +6,13 @@
|
||||
#include "server.h"
|
||||
|
||||
#include <QDataStream>
|
||||
#include <QDateTime>
|
||||
#include <QDebug>
|
||||
#include <QLatin1String>
|
||||
#include <QMap>
|
||||
#include <QString>
|
||||
#include <QStringList>
|
||||
#include <QTextStream>
|
||||
#include <QVariant>
|
||||
#include <Qt>
|
||||
#include <QtGlobal>
|
||||
#include <QtLogging>
|
||||
|
||||
#include <utility>
|
||||
@ -443,8 +441,6 @@ bool Chat::deserialize(QDataStream &stream, int version)
|
||||
if (!m_chatModel->deserialize(stream, version))
|
||||
return false;
|
||||
|
||||
m_llmodel->setStateFromText(m_chatModel->text());
|
||||
|
||||
emit chatModelChanged();
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "localdocsmodel.h" // IWYU pragma: keep
|
||||
#include "modellist.h"
|
||||
|
||||
#include <QDateTime>
|
||||
#include <QList>
|
||||
#include <QObject>
|
||||
#include <QQmlEngine>
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "chat.h"
|
||||
#include "chatapi.h"
|
||||
#include "chatmodel.h"
|
||||
#include "localdocs.h"
|
||||
#include "mysettings.h"
|
||||
#include "network.h"
|
||||
@ -13,10 +14,14 @@
|
||||
#include <QIODevice>
|
||||
#include <QJsonDocument>
|
||||
#include <QJsonObject>
|
||||
#include <QJsonValue>
|
||||
#include <QMap>
|
||||
#include <QMutex>
|
||||
#include <QMutexLocker>
|
||||
#include <QRegularExpression>
|
||||
#include <QSet>
|
||||
#include <QStringList>
|
||||
#include <QUrl>
|
||||
#include <QWaitCondition>
|
||||
#include <Qt>
|
||||
#include <QtLogging>
|
||||
@ -113,6 +118,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
|
||||
, m_reloadingToChangeVariant(false)
|
||||
, m_processedSystemPrompt(false)
|
||||
, m_restoreStateFromText(false)
|
||||
, m_chatModel(parent->chatModel())
|
||||
{
|
||||
moveToThread(&m_llmThread);
|
||||
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
|
||||
@ -1313,31 +1319,32 @@ void ChatLLM::processRestoreStateFromText()
|
||||
m_ctx.repeat_last_n = repeat_penalty_tokens;
|
||||
m_llModelInfo.model->setThreadCount(n_threads);
|
||||
|
||||
auto it = m_stateFromText.begin();
|
||||
while (it < m_stateFromText.end()) {
|
||||
Q_ASSERT(m_chatModel);
|
||||
m_chatModel->lock();
|
||||
auto it = m_chatModel->begin();
|
||||
while (it < m_chatModel->end()) {
|
||||
auto &prompt = *it++;
|
||||
Q_ASSERT(prompt.first == "Prompt: ");
|
||||
Q_ASSERT(it < m_stateFromText.end());
|
||||
Q_ASSERT(prompt.name == "Prompt: ");
|
||||
Q_ASSERT(it < m_chatModel->end());
|
||||
|
||||
auto &response = *it++;
|
||||
Q_ASSERT(response.first != "Prompt: ");
|
||||
Q_ASSERT(response.name == "Response: ");
|
||||
|
||||
// FIXME(jared): this doesn't work well with the "regenerate" button since we are not incrementing
|
||||
// m_promptTokens or m_promptResponseTokens
|
||||
m_llModelInfo.model->prompt(
|
||||
prompt.second.toStdString(), promptTemplate.toStdString(),
|
||||
prompt.value.toStdString(), promptTemplate.toStdString(),
|
||||
promptFunc, /*responseFunc*/ [](auto &&...) { return true; },
|
||||
/*allowContextShift*/ true,
|
||||
m_ctx,
|
||||
/*special*/ false,
|
||||
response.second.toUtf8().constData()
|
||||
response.value.toUtf8().constData()
|
||||
);
|
||||
}
|
||||
m_chatModel->unlock();
|
||||
|
||||
if (!m_stopGenerating) {
|
||||
if (!m_stopGenerating)
|
||||
m_restoreStateFromText = false;
|
||||
m_stateFromText.clear();
|
||||
}
|
||||
|
||||
m_restoringFromText = false;
|
||||
emit restoringFromTextChanged();
|
||||
|
@ -11,11 +11,10 @@
|
||||
#include <QFileInfo>
|
||||
#include <QList>
|
||||
#include <QObject>
|
||||
#include <QPair>
|
||||
#include <QPointer>
|
||||
#include <QString>
|
||||
#include <QThread>
|
||||
#include <QVariantMap>
|
||||
#include <QVector>
|
||||
#include <QtGlobal>
|
||||
|
||||
#include <atomic>
|
||||
@ -37,6 +36,7 @@ enum LLModelType {
|
||||
};
|
||||
|
||||
class ChatLLM;
|
||||
class ChatModel;
|
||||
|
||||
struct LLModelInfo {
|
||||
std::unique_ptr<LLModel> model;
|
||||
@ -151,7 +151,6 @@ public:
|
||||
|
||||
bool serialize(QDataStream &stream, int version, bool serializeKV);
|
||||
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
|
||||
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
|
||||
|
||||
public Q_SLOTS:
|
||||
bool prompt(const QList<QString> &collectionList, const QString &prompt);
|
||||
@ -244,7 +243,7 @@ private:
|
||||
// - an unload was queued during LLModel::restoreState()
|
||||
// - the chat will be restored from text and hasn't been interacted with yet
|
||||
bool m_pristineLoadedState = false;
|
||||
QVector<QPair<QString, QString>> m_stateFromText;
|
||||
QPointer<ChatModel> m_chatModel;
|
||||
};
|
||||
|
||||
#endif // CHATLLM_H
|
||||
|
@ -45,6 +45,8 @@ public:
|
||||
};
|
||||
Q_DECLARE_METATYPE(ChatItem)
|
||||
|
||||
using ChatModelIterator = QList<ChatItem>::const_iterator;
|
||||
|
||||
class ChatModel : public QAbstractListModel
|
||||
{
|
||||
Q_OBJECT
|
||||
@ -68,12 +70,14 @@ public:
|
||||
|
||||
int rowCount(const QModelIndex &parent = QModelIndex()) const override
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
Q_UNUSED(parent)
|
||||
return m_chatItems.size();
|
||||
}
|
||||
|
||||
QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size())
|
||||
return QVariant();
|
||||
|
||||
@ -125,75 +129,112 @@ public:
|
||||
ChatItem item;
|
||||
item.name = name;
|
||||
item.value = value;
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(item);
|
||||
m_mutex.lock();
|
||||
const int count = m_chatItems.count();
|
||||
m_mutex.unlock();
|
||||
beginInsertRows(QModelIndex(), count, count);
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
m_chatItems.append(item);
|
||||
}
|
||||
endInsertRows();
|
||||
emit countChanged();
|
||||
}
|
||||
|
||||
void appendResponse(const QString &name)
|
||||
{
|
||||
m_mutex.lock();
|
||||
const int count = m_chatItems.count();
|
||||
m_mutex.unlock();
|
||||
ChatItem item;
|
||||
item.id = m_chatItems.count(); // This is only relevant for responses
|
||||
item.id = count; // This is only relevant for responses
|
||||
item.name = name;
|
||||
item.currentResponse = true;
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(item);
|
||||
beginInsertRows(QModelIndex(), count, count);
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
m_chatItems.append(item);
|
||||
}
|
||||
endInsertRows();
|
||||
emit countChanged();
|
||||
}
|
||||
|
||||
Q_INVOKABLE void clear()
|
||||
{
|
||||
if (m_chatItems.isEmpty()) return;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (m_chatItems.isEmpty()) return;
|
||||
}
|
||||
|
||||
beginResetModel();
|
||||
m_chatItems.clear();
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
m_chatItems.clear();
|
||||
}
|
||||
endResetModel();
|
||||
emit countChanged();
|
||||
}
|
||||
|
||||
Q_INVOKABLE ChatItem get(int index)
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return ChatItem();
|
||||
return m_chatItems.at(index);
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateCurrentResponse(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.currentResponse != b) {
|
||||
item.currentResponse = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.currentResponse != b) {
|
||||
item.currentResponse = b;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {CurrentResponseRole});
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateStopped(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.stopped != b) {
|
||||
item.stopped = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.stopped != b) {
|
||||
item.stopped = b;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole});
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateValue(int index, const QString &value)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.value != value) {
|
||||
item.value = value;
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.value != value) {
|
||||
item.value = value;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) {
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole});
|
||||
emit valueChanged(index, value);
|
||||
}
|
||||
}
|
||||
|
||||
QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) {
|
||||
static QList<ResultInfo> consolidateSources(const QList<ResultInfo> &sources) {
|
||||
QMap<QString, ResultInfo> groupedData;
|
||||
for (const ResultInfo &info : sources) {
|
||||
if (groupedData.contains(info.file)) {
|
||||
@ -208,53 +249,77 @@ public:
|
||||
|
||||
Q_INVOKABLE void updateSources(int index, const QList<ResultInfo> &sources)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
item.sources = sources;
|
||||
item.consolidatedSources = consolidateSources(sources);
|
||||
ChatItem &item = m_chatItems[index];
|
||||
item.sources = sources;
|
||||
item.consolidatedSources = consolidateSources(sources);
|
||||
}
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {SourcesRole});
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ConsolidatedSourcesRole});
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateThumbsUpState(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsUpState != b) {
|
||||
item.thumbsUpState = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsUpState != b) {
|
||||
item.thumbsUpState = b;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsUpStateRole});
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateThumbsDownState(int index, bool b)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsDownState != b) {
|
||||
item.thumbsDownState = b;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.thumbsDownState != b) {
|
||||
item.thumbsDownState = b;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ThumbsDownStateRole});
|
||||
}
|
||||
|
||||
Q_INVOKABLE void updateNewResponse(int index, const QString &newResponse)
|
||||
{
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
bool changed = false;
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
if (index < 0 || index >= m_chatItems.size()) return;
|
||||
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.newResponse != newResponse) {
|
||||
item.newResponse = newResponse;
|
||||
emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
||||
ChatItem &item = m_chatItems[index];
|
||||
if (item.newResponse != newResponse) {
|
||||
item.newResponse = newResponse;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole});
|
||||
}
|
||||
|
||||
int count() const { return m_chatItems.size(); }
|
||||
int count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); }
|
||||
|
||||
ChatModelIterator begin() const { return m_chatItems.begin(); }
|
||||
ChatModelIterator end() const { return m_chatItems.end(); }
|
||||
void lock() { m_mutex.lock(); }
|
||||
void unlock() { m_mutex.unlock(); }
|
||||
|
||||
bool serialize(QDataStream &stream, int version) const
|
||||
{
|
||||
stream << count();
|
||||
QMutexLocker locker(&m_mutex);
|
||||
stream << int(m_chatItems.size());
|
||||
for (const auto &c : m_chatItems) {
|
||||
stream << c.id;
|
||||
stream << c.name;
|
||||
@ -442,28 +507,26 @@ public:
|
||||
c.consolidatedSources = consolidateSources(sources);
|
||||
}
|
||||
}
|
||||
beginInsertRows(QModelIndex(), m_chatItems.size(), m_chatItems.size());
|
||||
m_chatItems.append(c);
|
||||
m_mutex.lock();
|
||||
const int count = m_chatItems.size();
|
||||
m_mutex.unlock();
|
||||
beginInsertRows(QModelIndex(), count, count);
|
||||
{
|
||||
QMutexLocker locker(&m_mutex);
|
||||
m_chatItems.append(c);
|
||||
}
|
||||
endInsertRows();
|
||||
}
|
||||
emit countChanged();
|
||||
return stream.status() == QDataStream::Ok;
|
||||
}
|
||||
|
||||
QVector<QPair<QString, QString>> text() const
|
||||
{
|
||||
QVector<QPair<QString, QString>> result;
|
||||
for (const auto &c : m_chatItems)
|
||||
result << qMakePair(c.name, c.value);
|
||||
return result;
|
||||
}
|
||||
|
||||
Q_SIGNALS:
|
||||
void countChanged();
|
||||
void valueChanged(int index, const QString &value);
|
||||
|
||||
private:
|
||||
|
||||
mutable QMutex m_mutex;
|
||||
QList<ChatItem> m_chatItems;
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user