Get rid of last blocking operations and make the chat/llm thread safe.

Author: Adam Treat, 2023-06-20 16:14:30 -04:00 (committed by AT)
parent 84ec4311e9
commit c8a590bc6f
6 changed files with 60 additions and 72 deletions
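
The pattern this commit settles on, in both directions across the Chat/ChatLLM thread boundary: signals carry their payload by value through Qt::QueuedConnection, and each side caches what it receives in its own members, so no getter ever reaches into the other thread and no Qt::BlockingQueuedConnection is needed. Below is a minimal, self-contained sketch of that idiom; Worker, Frontend, and the single-file main.moc setup are illustrative stand-ins, not the project's actual classes.

// Sketch only: Worker/Frontend are hypothetical stand-ins for ChatLLM/Chat.
#include <QCoreApplication>
#include <QDebug>
#include <QObject>
#include <QString>
#include <QThread>

class Worker : public QObject
{
    Q_OBJECT
public Q_SLOTS:
    void doWork()
    {
        // The payload travels with the signal as a copy, so the receiver
        // never has to read this object's members across threads.
        emit responseChanged(QStringLiteral("partial response..."));
    }
Q_SIGNALS:
    void responseChanged(const QString &response);
};

class Frontend : public QObject
{
    Q_OBJECT
public:
    explicit Frontend(Worker *worker, QObject *parent = nullptr)
        : QObject(parent)
    {
        // Queued: the slot runs later, on this object's own thread.
        connect(worker, &Worker::responseChanged,
                this, &Frontend::handleResponseChanged, Qt::QueuedConnection);
    }
    // The getter reads only the cached copy; no cross-thread access.
    QString response() const { return m_response; }

private Q_SLOTS:
    void handleResponseChanged(const QString &response)
    {
        m_response = response;    // cache on the receiving side
        qDebug() << "cached:" << m_response;
        QCoreApplication::quit(); // demo only: exit once delivered
    }

private:
    QString m_response;
};

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QThread thread;
    Worker worker;
    worker.moveToThread(&thread);

    Frontend frontend(&worker);
    QObject::connect(&thread, &QThread::started, &worker, &Worker::doWork);
    thread.start();

    const int rc = app.exec(); // the event loop delivers the queued slot call
    thread.quit();
    thread.wait();
    return rc;
}

#include "main.moc" // single-file Qt sketch; moc runs via CMake AUTOMOC or qmake

Because the QString payload is copied into the queued invocation, the receiver never reads the sender's live state, which is why the commit can drop both the blocking connections and the "guarded with thread protection" comments below.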

gpt4all-chat/chat.cpp

@@ -16,6 +16,7 @@ Chat::Chat(QObject *parent)
     , m_llmodel(new ChatLLM(this))
     , m_isServer(false)
     , m_shouldDeleteLater(false)
+    , m_isModelLoaded(false)
 {
     connectLLM();
 }
@@ -31,6 +32,7 @@ Chat::Chat(bool isServer, QObject *parent)
     , m_llmodel(new Server(this))
     , m_isServer(true)
     , m_shouldDeleteLater(false)
+    , m_isModelLoaded(false)
 {
     connectLLM();
 }
@@ -55,12 +57,10 @@ void Chat::connectLLM()
     connect(m_watcher, &QFileSystemWatcher::directoryChanged, this, &Chat::handleModelListChanged);

     // Should be in different threads
-    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::modelNameChanged, this, &Chat::handleModelNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
@@ -73,11 +73,8 @@ void Chat::connectLLM()
     connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
     connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection);
-
-    // The following are blocking operations and will block the gui thread, therefore must be fast
-    // to respond to
-    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::BlockingQueuedConnection);
-    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::BlockingQueuedConnection);
+    connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection);
+    connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection);

     emit defaultModelChanged(modelList().first());
 }
@@ -87,7 +84,7 @@ void Chat::reset()
     stopGenerating();
     // Erase our current on disk representation as we're completely resetting the chat along with id
     LLM::globalInstance()->chatListModel()->removeChatFile(this);
-    emit resetContextRequested(); // blocking queued connection
+    emit resetContextRequested();
     m_id = Network::globalInstance()->generateUniqueId();
     emit idChanged(m_id);
     // NOTE: We deliberately do no reset the name or creation date to indictate that this was originally
@@ -105,7 +102,7 @@ void Chat::reset()
 bool Chat::isModelLoaded() const
 {
-    return m_llmodel->isModelLoaded();
+    return m_isModelLoaded;
 }

 void Chat::resetResponseState()
@@ -154,7 +151,7 @@ void Chat::stopGenerating()
 QString Chat::response() const
 {
-    return m_llmodel->response();
+    return m_response;
 }

 QString Chat::responseState() const
@@ -170,22 +167,29 @@ QString Chat::responseState() const
     return QString();
 }

-void Chat::handleResponseChanged()
+void Chat::handleResponseChanged(const QString &response)
 {
     if (m_responseState != Chat::ResponseGeneration) {
         m_responseState = Chat::ResponseGeneration;
         emit responseStateChanged();
     }

+    m_response = response;
     const int index = m_chatModel->count() - 1;
-    m_chatModel->updateValue(index, response());
+    m_chatModel->updateValue(index, this->response());
     emit responseChanged();
 }

-void Chat::handleModelLoadedChanged()
+void Chat::handleModelLoadedChanged(bool loaded)
 {
     if (m_shouldDeleteLater)
         deleteLater();
+
+    if (loaded == m_isModelLoaded)
+        return;
+
+    m_isModelLoaded = loaded;
+    emit isModelLoadedChanged();
 }

 void Chat::promptProcessing()
@@ -241,7 +245,7 @@ void Chat::responseStopped()
     m_responseState = Chat::ResponseStopped;
     emit responseInProgressChanged();
     emit responseStateChanged();
-    if (m_llmodel->generatedName().isEmpty())
+    if (m_generatedName.isEmpty())
         emit generateNameRequested();
     if (chatModel()->count() < 3)
         Network::globalInstance()->sendChatStarted();
@@ -249,15 +253,18 @@ void Chat::responseStopped()
 QString Chat::modelName() const
 {
-    return m_llmodel->modelName();
+    return m_modelName;
 }

 void Chat::setModelName(const QString &modelName)
 {
-    // doesn't block but will unload old model and load new one which the gui can see through changes
-    // to the isModelLoaded property
+    if (m_modelName == modelName)
+        return;
+
     m_modelLoadingError = QString();
     emit modelLoadingErrorChanged();
+    m_modelName = modelName;
+    emit modelNameChanged();
     emit modelNameChangeRequested(modelName);
 }
@@ -267,7 +274,7 @@ void Chat::newPromptResponsePair(const QString &prompt)
     m_chatModel->updateCurrentResponse(m_chatModel->count() - 1, false);
     m_chatModel->appendPrompt(tr("Prompt: "), prompt);
     m_chatModel->appendResponse(tr("Response: "), prompt);
-    emit resetResponseRequested(); // blocking queued connection
+    emit resetResponseRequested();
 }

 void Chat::serverNewPromptResponsePair(const QString &prompt)
@@ -320,11 +327,11 @@ void Chat::reloadModel()
     m_llmodel->setShouldBeLoaded(true);
 }

-void Chat::generatedNameChanged()
+void Chat::generatedNameChanged(const QString &name)
 {
     // Only use the first three words maximum and remove newlines and extra spaces
-    QString gen = m_llmodel->generatedName().simplified();
-    QStringList words = gen.split(' ', Qt::SkipEmptyParts);
+    m_generatedName = name.simplified();
+    QStringList words = m_generatedName.split(' ', Qt::SkipEmptyParts);
     int wordCount = qMin(3, words.size());
     m_name = words.mid(0, wordCount).join(' ');
     emit nameChanged();
@@ -336,12 +343,6 @@ void Chat::handleRecalculating()
     emit recalcChanged();
 }

-void Chat::handleModelNameChanged()
-{
-    m_savedModelName = modelName();
-    emit modelNameChanged();
-}
-
 void Chat::handleModelLoadingError(const QString &error)
 {
     qWarning() << "ERROR:" << qPrintable(error) << "id" << id();
@@ -366,7 +367,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
     stream << m_id;
     stream << m_name;
     stream << m_userName;
-    stream << m_savedModelName;
+    stream << m_modelName;
     if (version > 2)
         stream << m_collections;
     if (!m_llmodel->serialize(stream, version))
@@ -384,16 +385,17 @@ bool Chat::deserialize(QDataStream &stream, int version)
     stream >> m_name;
     stream >> m_userName;
     emit nameChanged();
-    stream >> m_savedModelName;
+    stream >> m_modelName;
+    emit modelNameChanged();

     // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
     // unfortunately, we cannot deserialize these
-    if (version < 2 && m_savedModelName.contains("gpt4all-j"))
+    if (version < 2 && m_modelName.contains("gpt4all-j"))
         return false;
     if (version > 2) {
         stream >> m_collections;
         emit collectionListChanged(m_collections);
     }
-    m_llmodel->setModelName(m_savedModelName);
+    m_llmodel->setModelName(m_modelName);
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))

gpt4all-chat/chat.h

@@ -125,13 +125,12 @@ Q_SIGNALS:
     void defaultModelChanged(const QString &defaultModel);

 private Q_SLOTS:
-    void handleResponseChanged();
-    void handleModelLoadedChanged();
+    void handleResponseChanged(const QString &response);
+    void handleModelLoadedChanged(bool);
     void promptProcessing();
     void responseStopped();
-    void generatedNameChanged();
+    void generatedNameChanged(const QString &name);
     void handleRecalculating();
-    void handleModelNameChanged();
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
@@ -141,10 +140,12 @@ private Q_SLOTS:
 private:
     QString m_id;
     QString m_name;
+    QString m_generatedName;
     QString m_userName;
-    QString m_savedModelName;
+    QString m_modelName;
     QString m_modelLoadingError;
     QString m_tokenSpeed;
+    QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
     bool m_responseInProgress;
@@ -154,6 +155,7 @@ private:
     QList<ResultInfo> m_databaseResults;
     bool m_isServer;
     bool m_shouldDeleteLater;
+    bool m_isModelLoaded;
     QFileSystemWatcher *m_watcher;
 };

gpt4all-chat/chatlistmodel.cpp

@@ -233,7 +233,6 @@ void ChatListModel::restoreChat(Chat *chat)
 {
     chat->setParent(this);
     connect(chat, &Chat::nameChanged, this, &ChatListModel::nameChanged);
-    connect(chat, &Chat::modelLoadingErrorChanged, this, &ChatListModel::handleModelLoadingError);

     if (m_dummyChat) {
         beginResetModel();

gpt4all-chat/chatlistmodel.h

@@ -122,8 +122,6 @@ public:
             this, &ChatListModel::newChatCountChanged);
         connect(m_newChat, &Chat::nameChanged,
             this, &ChatListModel::nameChanged);
-        connect(m_newChat, &Chat::modelLoadingError,
-            this, &ChatListModel::handleModelLoadingError);
         setCurrentChat(m_newChat);
     }
@@ -227,12 +225,6 @@ private Q_SLOTS:
         emit dataChanged(index, index, {NameRole});
     }

-    void handleModelLoadingError()
-    {
-        Chat *chat = qobject_cast<Chat *>(sender());
-        removeChat(chat);
-    }
-
     void printChats()
     {
         for (auto c : m_chats) {

gpt4all-chat/chatllm.cpp

@@ -11,7 +11,6 @@
 #include <QProcess>
 #include <QResource>
 #include <QSettings>
-#include <fstream>

 //#define DEBUG
 //#define DEBUG_MODEL_LOADING
@@ -154,7 +153,7 @@ bool ChatLLM::loadModel(const QString &modelName)
     // to provide an overview of what we're doing here.

     // We're already loaded with this model
-    if (isModelLoaded() && m_modelName == modelName)
+    if (isModelLoaded() && this->modelName() == modelName)
         return true;

     bool isChatGPT = modelName.startsWith("chatgpt-");
@@ -170,7 +169,7 @@ bool ChatLLM::loadModel(const QString &modelName)
 #endif
         delete m_modelInfo.model;
         m_modelInfo.model = nullptr;
-        emit isModelLoadedChanged();
+        emit isModelLoadedChanged(false);
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
@@ -188,7 +187,7 @@ bool ChatLLM::loadModel(const QString &modelName)
 #endif
             LLModelStore::globalInstance()->releaseModel(m_modelInfo);
             m_modelInfo = LLModelInfo();
-            emit isModelLoadedChanged();
+            emit isModelLoadedChanged(false);
             return false;
         }
@@ -198,7 +197,7 @@ bool ChatLLM::loadModel(const QString &modelName)
             qDebug() << "store had our model" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
             restoreState();
-            emit isModelLoadedChanged();
+            emit isModelLoadedChanged(true);
             return true;
         } else {
             // Release the memory since we have to switch to a different model.
@@ -273,7 +272,7 @@ bool ChatLLM::loadModel(const QString &modelName)
         qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
         fflush(stdout);
 #endif
-        emit isModelLoadedChanged();
+        emit isModelLoadedChanged(isModelLoaded());

         static bool isFirstLoad = true;
         if (isFirstLoad) {
@@ -316,7 +315,7 @@ void ChatLLM::regenerateResponse()
     m_promptResponseTokens = 0;
     m_promptTokens = 0;
     m_response = std::string();
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
 }

 void ChatLLM::resetResponse()
@@ -324,7 +323,7 @@ void ChatLLM::resetResponse()
     m_promptTokens = 0;
     m_promptResponseTokens = 0;
     m_response = std::string();
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
 }

 void ChatLLM::resetContext()
@@ -397,7 +396,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     // check for error
     if (token < 0) {
         m_response.append(response);
-        emit responseChanged();
+        emit responseChanged(QString::fromStdString(m_response));
         return false;
     }
@@ -407,7 +406,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
     m_timer->inc();
     Q_ASSERT(!response.empty());
     m_response.append(response);
-    emit responseChanged();
+    emit responseChanged(QString::fromStdString(m_response));
     return !m_stopGenerating;
 }
@@ -470,7 +469,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
     std::string trimmed = trim_whitespace(m_response);
     if (trimmed != m_response) {
         m_response = trimmed;
-        emit responseChanged();
+        emit responseChanged(QString::fromStdString(m_response));
     }
     emit responseStopped();
     return true;
@@ -510,7 +509,7 @@ void ChatLLM::unloadModel()
 #endif
     LLModelStore::globalInstance()->releaseModel(m_modelInfo);
     m_modelInfo = LLModelInfo();
-    emit isModelLoadedChanged();
+    emit isModelLoadedChanged(false);
 }

 void ChatLLM::reloadModel()
@@ -521,11 +520,11 @@ void ChatLLM::reloadModel()
 #if defined(DEBUG_MODEL_LOADING)
     qDebug() << "reloadModel" << m_llmThread.objectName() << m_modelInfo.model;
 #endif
-    if (m_modelName.isEmpty()) {
+    const QString m = modelName();
+    if (m.isEmpty())
         loadDefaultModel();
-    } else {
-        loadModel(m_modelName);
-    }
+    else
+        loadModel(m);
 }

 void ChatLLM::generateName()
@@ -554,7 +553,7 @@ void ChatLLM::generateName()
     std::string trimmed = trim_whitespace(m_nameResponse);
     if (trimmed != m_nameResponse) {
         m_nameResponse = trimmed;
-        emit generatedNameChanged();
+        emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     }
 }
@@ -580,7 +579,7 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
     Q_UNUSED(token);

     m_nameResponse.append(response);
-    emit generatedNameChanged();
+    emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     QString gen = QString::fromStdString(m_nameResponse).simplified();
     QStringList words = gen.split(' ', Qt::SkipEmptyParts);
     return words.size() <= 3;

gpt4all-chat/chatllm.h

@@ -116,16 +116,16 @@ public Q_SLOTS:
     void handleThreadStarted();

 Q_SIGNALS:
-    void isModelLoadedChanged();
+    void isModelLoadedChanged(bool);
     void modelLoadingError(const QString &error);
-    void responseChanged();
+    void responseChanged(const QString &response);
     void promptProcessing();
     void responseStopped();
     void modelNameChanged();
     void recalcChanged();
     void sendStartup();
     void sendModelLoaded();
-    void generatedNameChanged();
+    void generatedNameChanged(const QString &name);
     void stateChanged();
     void threadStarted();
     void shouldBeLoadedChanged();
@@ -144,22 +144,16 @@ protected:
     void restoreState();

 protected:
-    // The following are all accessed by multiple threads and are thus guarded with thread protection
-    // mechanisms
     LLModel::PromptContext m_ctx;
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;

 private:
-    // The following are all accessed by multiple threads and are thus guarded with thread protection
-    // mechanisms
     std::string m_response;
     std::string m_nameResponse;
     LLModelInfo m_modelInfo;
     LLModelType m_modelType;
     QString m_modelName;
-
-    // The following are only accessed by this thread
     QString m_defaultModel;
     TokenTimer *m_timer;
     QByteArray m_state;