From 7e1e00f3311c596ef807e9c139bfae7cef3e2c49 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Wed, 15 May 2024 14:07:03 -0400
Subject: [PATCH] chat: fix issues with quickly switching between multiple chats (#2343)

* prevent load progress from getting out of sync with the current chat
* fix memory leak on exit if the LLModelStore contains a model
* do not report cancellation as a failure in console/Mixpanel
* show "waiting for model" separately from "switching context" in UI
* do not show lower "reload" button on error
* skip context switch if unload is pending
* skip unnecessary calls to LLModel::saveState

Signed-off-by: Jared Van Bortel
---
 gpt4all-chat/chat.cpp         |  41 +++----
 gpt4all-chat/chat.h           |  17 ++-
 gpt4all-chat/chatlistmodel.h  |   6 +-
 gpt4all-chat/chatllm.cpp      | 194 ++++++++++++++++++++--------------
 gpt4all-chat/chatllm.h        |  19 ++--
 gpt4all-chat/qml/ChatView.qml |  45 +++-----
 6 files changed, 179 insertions(+), 143 deletions(-)

diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 81c9b234..e8bbfa8b 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -54,7 +54,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::reportFallbackReason, this, &Chat::handleFallbackReasonChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::trySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
     connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection);
@@ -95,16 +95,6 @@ void Chat::processSystemPrompt()
     emit processSystemPromptRequested();
 }
 
-bool Chat::isModelLoaded() const
-{
-    return m_modelLoadingPercentage == 1.0f;
-}
-
-float Chat::modelLoadingPercentage() const
-{
-    return m_modelLoadingPercentage;
-}
-
 void Chat::resetResponseState()
 {
     if (m_responseInProgress && m_responseState == Chat::LocalDocsRetrieval)
@@ -167,9 +157,16 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)
     if (loadingPercentage == m_modelLoadingPercentage)
         return;
 
+    bool wasLoading = isCurrentlyLoading();
+    bool wasLoaded = isModelLoaded();
+
     m_modelLoadingPercentage = loadingPercentage;
     emit modelLoadingPercentageChanged();
-    if (m_modelLoadingPercentage == 1.0f || m_modelLoadingPercentage == 0.0f)
+
+    if (isCurrentlyLoading() != wasLoading)
+        emit isCurrentlyLoadingChanged();
+
+    if (isModelLoaded() != wasLoaded)
         emit isModelLoadedChanged();
 }
 
@@ -247,10 +244,6 @@ void Chat::setModelInfo(const ModelInfo &modelInfo)
     if (m_modelInfo == modelInfo && isModelLoaded())
         return;
 
-    m_modelLoadingPercentage = std::numeric_limits<float>::min(); // small non-zero positive value
-    emit isModelLoadedChanged();
-    m_modelLoadingError = QString();
-    emit modelLoadingErrorChanged();
     m_modelInfo = modelInfo;
     emit modelInfoChanged();
     emit modelChangeRequested(modelInfo);
@@ -320,8 +313,9 @@ void Chat::forceReloadModel()
 
 void Chat::trySwitchContextOfLoadedModel()
 {
-    emit trySwitchContextOfLoadedModelAttempted();
-    m_llmodel->setShouldTrySwitchContext(true);
+    m_trySwitchContextInProgress = 1;
+    emit trySwitchContextInProgressChanged();
+    m_llmodel->requestTrySwitchContext();
 }
 
 void Chat::generatedNameChanged(const QString &name)
@@ -342,8 +336,10 @@ void Chat::handleRecalculating()
 
 void Chat::handleModelLoadingError(const QString &error)
 {
-    auto stream = qWarning().noquote() << "ERROR:" << error << "id";
-    stream.quote() << id();
+    if (!error.isEmpty()) {
+        auto stream = qWarning().noquote() << "ERROR:" << error << "id";
+        stream.quote() << id();
+    }
     m_modelLoadingError = error;
     emit modelLoadingErrorChanged();
 }
@@ -380,6 +376,11 @@ void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
     emit modelInfoChanged();
 }
 
+void Chat::handleTrySwitchContextOfLoadedModelCompleted(int value) {
+    m_trySwitchContextInProgress = value;
+    emit trySwitchContextInProgressChanged();
+}
+
 bool Chat::serialize(QDataStream &stream, int version) const
 {
     stream << m_creationDate;
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index 9e6fc8bd..0859cb8b 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -17,6 +17,7 @@ class Chat : public QObject
     Q_PROPERTY(QString name READ name WRITE setName NOTIFY nameChanged)
     Q_PROPERTY(ChatModel *chatModel READ chatModel NOTIFY chatModelChanged)
     Q_PROPERTY(bool isModelLoaded READ isModelLoaded NOTIFY isModelLoadedChanged)
+    Q_PROPERTY(bool isCurrentlyLoading READ isCurrentlyLoading NOTIFY isCurrentlyLoadingChanged)
     Q_PROPERTY(float modelLoadingPercentage READ modelLoadingPercentage NOTIFY modelLoadingPercentageChanged)
     Q_PROPERTY(QString response READ response NOTIFY responseChanged)
     Q_PROPERTY(ModelInfo modelInfo READ modelInfo WRITE setModelInfo NOTIFY modelInfoChanged)
@@ -30,6 +31,8 @@ class Chat : public QObject
     Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
     Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY fallbackReasonChanged);
     Q_PROPERTY(LocalDocsCollectionsModel *collectionModel READ collectionModel NOTIFY collectionModelChanged)
+    // 0=no, 1=waiting, 2=working
+    Q_PROPERTY(int trySwitchContextInProgress READ trySwitchContextInProgress NOTIFY trySwitchContextInProgressChanged)
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
 
@@ -62,8 +65,9 @@ public:
     Q_INVOKABLE void reset();
     Q_INVOKABLE void processSystemPrompt();
 
-    Q_INVOKABLE bool isModelLoaded() const;
-    Q_INVOKABLE float modelLoadingPercentage() const;
+    bool isModelLoaded() const { return m_modelLoadingPercentage == 1.0f; }
+    bool isCurrentlyLoading() const { return m_modelLoadingPercentage > 0.0f && m_modelLoadingPercentage < 1.0f; }
+    float modelLoadingPercentage() const { return m_modelLoadingPercentage; }
     Q_INVOKABLE void prompt(const QString &prompt);
     Q_INVOKABLE void regenerateResponse();
     Q_INVOKABLE void stopGenerating();
@@ -105,6 +109,8 @@ public:
     QString device() const { return m_device; }
     QString fallbackReason() const { return m_fallbackReason; }
 
+    int trySwitchContextInProgress() const { return m_trySwitchContextInProgress; }
+
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
 
@@ -113,6 +119,7 @@ Q_SIGNALS:
     void nameChanged();
     void chatModelChanged();
     void isModelLoadedChanged();
+    void isCurrentlyLoadingChanged();
     void modelLoadingPercentageChanged();
     void modelLoadingWarning(const QString &warning);
     void responseChanged();
@@ -136,8 +143,7 @@ Q_SIGNALS:
     void deviceChanged();
     void fallbackReasonChanged();
     void collectionModelChanged();
-    void trySwitchContextOfLoadedModelAttempted();
-    void trySwitchContextOfLoadedModelCompleted(bool);
+    void trySwitchContextInProgressChanged();
 
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -152,6 +158,7 @@ private Q_SLOTS:
     void handleFallbackReasonChanged(const QString &device);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
+    void handleTrySwitchContextOfLoadedModelCompleted(int value);
 
 private:
     QString m_id;
@@ -176,6 +183,8 @@ private:
     float m_modelLoadingPercentage = 0.0f;
     LocalDocsCollectionsModel *m_collectionModel;
     bool m_firstResponse = true;
+    int m_trySwitchContextInProgress = 0;
+    bool m_isCurrentlyLoading = false;
 };
 
 #endif // CHAT_H
diff --git a/gpt4all-chat/chatlistmodel.h b/gpt4all-chat/chatlistmodel.h
index 7460bc80..709391e0 100644
--- a/gpt4all-chat/chatlistmodel.h
+++ b/gpt4all-chat/chatlistmodel.h
@@ -195,7 +195,11 @@ public:
     int count() const { return m_chats.size(); }
 
     // stop ChatLLM threads for clean shutdown
-    void destroyChats() { for (auto *chat: m_chats) { chat->destroy(); } }
+    void destroyChats()
+    {
+        for (auto *chat: m_chats) { chat->destroy(); }
+        ChatLLM::destroyStore();
+    }
     void removeChatFile(Chat *chat) const;
     Q_INVOKABLE void saveChats();
diff --git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp
index e54301c8..333cdaa9 100644
--- a/gpt4all-chat/chatllm.cpp
+++ b/gpt4all-chat/chatllm.cpp
@@ -30,16 +30,17 @@ public:
     static LLModelStore *globalInstance();
 
     LLModelInfo acquireModel(); // will block until llmodel is ready
-    void releaseModel(const LLModelInfo &info); // must be called when you are done
+    void releaseModel(LLModelInfo &&info); // must be called when you are done
+    void destroy();
 
 private:
     LLModelStore()
     {
         // seed with empty model
-        m_availableModels.append(LLModelInfo());
+        m_availableModel = LLModelInfo();
     }
     ~LLModelStore() {}
-    QVector<LLModelInfo> m_availableModels;
+    std::optional<LLModelInfo> m_availableModel;
     QMutex m_mutex;
     QWaitCondition m_condition;
     friend class MyLLModelStore;
@@ -55,19 +56,27 @@ LLModelStore *LLModelStore::globalInstance()
 LLModelInfo LLModelStore::acquireModel()
 {
     QMutexLocker locker(&m_mutex);
-    while (m_availableModels.isEmpty())
+    while (!m_availableModel)
         m_condition.wait(locker.mutex());
-    return m_availableModels.takeFirst();
+    auto first = std::move(*m_availableModel);
+    m_availableModel.reset();
+    return first;
 }
 
-void LLModelStore::releaseModel(const LLModelInfo &info)
+void LLModelStore::releaseModel(LLModelInfo &&info)
 {
     QMutexLocker locker(&m_mutex);
-    m_availableModels.append(info);
-    Q_ASSERT(m_availableModels.count() < 2);
+    Q_ASSERT(!m_availableModel);
+    m_availableModel = std::move(info);
     m_condition.wakeAll();
 }
 
+void LLModelStore::destroy()
+{
+    QMutexLocker locker(&m_mutex);
+    m_availableModel.reset();
+}
+
 ChatLLM::ChatLLM(Chat *parent, bool isServer)
     : QObject{nullptr}
     , m_promptResponseTokens(0)
@@ -76,7 +85,6 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_shouldBeLoaded(false)
     , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
-    , m_shouldTrySwitchContext(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
@@ -88,7 +96,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
-    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
+    connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel,
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
@@ -108,7 +116,8 @@ ChatLLM::~ChatLLM()
     destroy();
 }
 
-void ChatLLM::destroy() {
+void ChatLLM::destroy()
+{
     m_stopGenerating = true;
     m_llmThread.quit();
     m_llmThread.wait();
@@ -116,11 +125,15 @@ void ChatLLM::destroy() {
     // The only time we should have a model loaded here is on shutdown
     // as we explicitly unload the model in all other circumstances
     if (isModelLoaded()) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
     }
 }
 
+void ChatLLM::destroyStore()
+{
+    LLModelStore::globalInstance()->destroy();
+}
+
 void ChatLLM::handleThreadStarted()
 {
     m_timer = new TokenTimer(this);
@@ -161,7 +174,7 @@ bool ChatLLM::loadDefaultModel()
     return loadModel(defaultModel);
 }
 
-bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
+void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
     // and if so we just acquire it from the store and switch the context and return true. If the
 
     // If we're already loaded or a server or we're reloading to change the variant/device or the
     // modelInfo is empty, then this should fail
-    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (
+        isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty() || !m_shouldBeLoaded
+    ) {
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
 
     QString filePath = modelInfo.dirpath + modelInfo.filename();
 
     m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
     // The store gave us no already loaded model, the wrong type of model, then give it back to the
     // store and fail
-    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
-        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-        m_llModelInfo = LLModelInfo();
-        m_shouldTrySwitchContext = false;
-        emit trySwitchContextOfLoadedModelCompleted(false);
-        return false;
+    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo || !m_shouldBeLoaded) {
+        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+        emit trySwitchContextOfLoadedModelCompleted(0);
+        return;
     }
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    // We should be loaded and now we are
-    m_shouldBeLoaded = true;
-    m_shouldTrySwitchContext = false;
+    emit trySwitchContextOfLoadedModelCompleted(2);
 
     // Restore, signal and process
     restoreState();
     emit modelLoadingPercentageChanged(1.0f);
-    emit trySwitchContextOfLoadedModelCompleted(true);
+    emit trySwitchContextOfLoadedModelCompleted(0);
     processSystemPrompt();
-    return true;
 }
 
 bool ChatLLM::loadModel(const ModelInfo &modelInfo)
@@ -223,6 +232,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (isModelLoaded() && this->modelInfo() == modelInfo)
         return true;
 
+    // reset status
+    emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    emit modelLoadingError("");
+    emit reportFallbackReason("");
+    emit reportDevice("");
+    m_pristineLoadedState = false;
+
     QString filePath = modelInfo.dirpath + modelInfo.filename();
     QFileInfo fileInfo(filePath);
 
@@ -231,28 +247,25 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
     if (alreadyAcquired) {
         resetContext();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "already acquired model deleted" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+        m_llModelInfo.model.reset();
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
         // returned to it, then the modelInfo.model pointer should be null which will happen on startup
         m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
 #if defined(DEBUG_MODEL_LOADING)
-        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+        qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
         // At this point it is possible that while we were blocked waiting to acquire the model from the
         // store, that our state was changed to not be loaded. If this is the case, release the model
         // back into the store and quit loading
        if (!m_shouldBeLoaded) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "no longer need model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-            m_llModelInfo = LLModelInfo();
+            LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
             emit modelLoadingPercentageChanged(0.0f);
             return false;
         }
@@ -260,7 +273,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         // Check if the store just gave us exactly the model we were looking for
         if (m_llModelInfo.model && m_llModelInfo.fileInfo == fileInfo && !m_reloadingToChangeVariant) {
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
             restoreState();
             emit modelLoadingPercentageChanged(1.0f);
@@ -274,10 +287,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         } else {
             // Release the memory since we have to switch to a different model.
 #if defined(DEBUG_MODEL_LOADING)
-            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model;
+            qDebug() << "deleting model" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
-            delete m_llModelInfo.model;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
         }
     }
 
@@ -307,7 +319,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             model->setModelName(modelName);
             model->setRequestURL(modelInfo.url());
             model->setAPIKey(apiKey);
-            m_llModelInfo.model = model;
+            m_llModelInfo.model.reset(model);
         } else {
             QElapsedTimer modelLoadTimer;
             modelLoadTimer.start();
@@ -322,9 +334,10 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 buildVariant = "metal";
 #endif
             QString constructError;
-            m_llModelInfo.model = nullptr;
+            m_llModelInfo.model.reset();
             try {
-                m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                auto *model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
+                m_llModelInfo.model.reset(model);
             } catch (const LLModel::MissingImplementationError &e) {
                 modelLoadProps.insert("error", "missing_model_impl");
                 constructError = e.what();
@@ -355,8 +368,6 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                     return m_shouldBeLoaded;
                 });
-                emit reportFallbackReason(""); // no fallback yet
-
                 auto approxDeviceMemGB = [](const LLModel::GPUDevice *dev) {
                     float memGB = dev->heapSize / float(1024 * 1024 * 1024);
                     return std::floor(memGB * 10.f) / 10.f; // truncate to 1 decimal place
                 };
@@ -407,6 +418,16 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
                 emit reportDevice(actualDevice);
 
                 bool success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, ngl);
+
+                if (!m_shouldBeLoaded) {
+                    m_llModelInfo.model.reset();
+                    if (!m_isServer)
+                        LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+                    m_llModelInfo = LLModelInfo();
+                    emit modelLoadingPercentageChanged(0.0f);
+                    return false;
+                }
+
                 if (actualDevice == "CPU") {
                     // we asked llama.cpp to use the CPU
                 } else if (!success) {
GPU loading failed (out of VRAM?)"); modelLoadProps.insert("cpu_fallback_reason", "gpu_load_failed"); success = m_llModelInfo.model->loadModel(filePath.toStdString(), n_ctx, 0); + + if (!m_shouldBeLoaded) { + m_llModelInfo.model.reset(); + if (!m_isServer) + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); + m_llModelInfo = LLModelInfo(); + emit modelLoadingPercentageChanged(0.0f); + return false; + } } else if (!m_llModelInfo.model->usingGPUDevice()) { // ggml_vk_init was not called in llama.cpp // We might have had to fallback to CPU after load if the model is not possible to accelerate @@ -425,10 +455,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) } if (!success) { - delete m_llModelInfo.model; - m_llModelInfo.model = nullptr; + m_llModelInfo.model.reset(); if (!m_isServer) - LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename())); modelLoadProps.insert("error", "loadmodel_failed"); @@ -438,10 +467,9 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) case 'G': m_llModelType = LLModelType::GPTJ_; break; default: { - delete m_llModelInfo.model; - m_llModelInfo.model = nullptr; + m_llModelInfo.model.reset(); if (!m_isServer) - LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); emit modelLoadingError(QString("Could not determine model type for %1").arg(modelInfo.filename())); } @@ -451,13 +479,13 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) } } else { if (!m_isServer) - LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); m_llModelInfo = LLModelInfo(); emit modelLoadingError(QString("Error loading %1: %2").arg(modelInfo.filename()).arg(constructError)); } } #if defined(DEBUG_MODEL_LOADING) - qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model; + qDebug() << "new model" << m_llmThread.objectName() << m_llModelInfo.model.get(); #endif restoreState(); #if defined(DEBUG) @@ -471,7 +499,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) Network::globalInstance()->trackChatEvent("model_load", modelLoadProps); } else { if (!m_isServer) - LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store + LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo)); // release back into the store m_llModelInfo = LLModelInfo(); emit modelLoadingError(QString("Could not find file for model %1").arg(modelInfo.filename())); } @@ -480,7 +508,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo) setModelInfo(modelInfo); processSystemPrompt(); } - return m_llModelInfo.model; + return bool(m_llModelInfo.model); } bool ChatLLM::isModelLoaded() const @@ -700,22 +728,23 @@ bool ChatLLM::promptInternal(const QList &collectionList, const QString emit responseChanged(QString::fromStdString(m_response)); } emit responseStopped(elapsed); + m_pristineLoadedState = false; return true; } void ChatLLM::setShouldBeLoaded(bool b) { #if defined(DEBUG_MODEL_LOADING) - qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model; + qDebug() << "setShouldBeLoaded" << 
     qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
 #endif
     m_shouldBeLoaded = b; // atomic
     emit shouldBeLoadedChanged();
 }
 
-void ChatLLM::setShouldTrySwitchContext(bool b)
+void ChatLLM::requestTrySwitchContext()
 {
-    m_shouldTrySwitchContext = b; // atomic
-    emit shouldTrySwitchContextChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit trySwitchContextRequested(modelInfo());
 }
 
 void ChatLLM::handleShouldBeLoadedChanged()
@@ -726,12 +755,6 @@ void ChatLLM::handleShouldBeLoadedChanged()
         unloadModel();
 }
 
-void ChatLLM::handleShouldTrySwitchContextChanged()
-{
-    if (m_shouldTrySwitchContext)
-        trySwitchContextOfLoadedModel(modelInfo());
-}
-
 void ChatLLM::unloadModel()
 {
     if (!isModelLoaded() || m_isServer)
@@ -746,17 +769,16 @@ void ChatLLM::unloadModel()
     saveState();
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
     if (m_forceUnloadModel) {
-        delete m_llModelInfo.model;
-        m_llModelInfo.model = nullptr;
+        m_llModelInfo.model.reset();
         m_forceUnloadModel = false;
     }
 
-    LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
-    m_llModelInfo = LLModelInfo();
+    LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::reloadModel()
@@ -768,7 +790,7 @@ void ChatLLM::reloadModel()
         return;
 
 #if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
+    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
     const ModelInfo m = modelInfo();
     if (m.name().isEmpty())
@@ -795,6 +817,7 @@ void ChatLLM::generateName()
         m_nameResponse = trimmed;
         emit generatedNameChanged(QString::fromStdString(m_nameResponse));
     }
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::handleChatIdChanged(const QString &id)
@@ -934,7 +957,10 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
     // If we do not deserialize the KV or it is discarded, then we need to restore the state from the
     // text only. This will be a costly operation, but the chat has to be restored from the text archive
     // alone.
-    m_restoreStateFromText = !deserializeKV || discardKV;
+    if (!deserializeKV || discardKV) {
+        m_restoreStateFromText = true;
+        m_pristineLoadedState = true;
+    }
 
     if (!deserializeKV) {
 #if defined(DEBUG)
@@ -998,14 +1024,14 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
 
 void ChatLLM::saveState()
 {
-    if (!isModelLoaded())
+    if (!isModelLoaded() || m_pristineLoadedState)
         return;
 
     if (m_llModelType == LLModelType::API_) {
         m_state.clear();
         QDataStream stream(&m_state, QIODeviceBase::WriteOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model.get());
         stream << chatAPI->context();
         return;
     }
@@ -1026,7 +1052,7 @@ void ChatLLM::restoreState()
     if (m_llModelType == LLModelType::API_) {
         QDataStream stream(&m_state, QIODeviceBase::ReadOnly);
         stream.setVersion(QDataStream::Qt_6_4);
-        ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model);
+        ChatAPI *chatAPI = static_cast<ChatAPI *>(m_llModelInfo.model.get());
         QList context;
         stream >> context;
         chatAPI->setContext(context);
@@ -1045,13 +1071,18 @@ void ChatLLM::restoreState()
     if (m_llModelInfo.model->stateSize() == m_state.size()) {
         m_llModelInfo.model->restoreState(static_cast(reinterpret_cast(m_state.data())));
         m_processedSystemPrompt = true;
+        m_pristineLoadedState = true;
     } else {
         qWarning() << "restoring state from text because" << m_llModelInfo.model->stateSize() << "!=" << m_state.size();
         m_restoreStateFromText = true;
     }
 
-    m_state.clear();
-    m_state.squeeze();
+    // free local state copy unless unload is pending
+    if (m_shouldBeLoaded) {
+        m_state.clear();
+        m_state.squeeze();
+        m_pristineLoadedState = false;
+    }
 }
 
 void ChatLLM::processSystemPrompt()
@@ -1105,6 +1136,7 @@ void ChatLLM::processSystemPrompt()
 #endif
 
     m_processedSystemPrompt = m_stopGenerating == false;
+    m_pristineLoadedState = false;
 }
 
 void ChatLLM::processRestoreStateFromText()
@@ -1163,4 +1195,6 @@ void ChatLLM::processRestoreStateFromText()
 
     m_isRecalc = false;
     emit recalcChanged();
+
+    m_pristineLoadedState = false;
 }
diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h
index 7ef6e96f..fabb0164 100644
--- a/gpt4all-chat/chatllm.h
+++ b/gpt4all-chat/chatllm.h
@@ -5,6 +5,8 @@
 #include
 #include
+#include <memory>
+
 #include "database.h"
 #include "modellist.h"
 #include "../gpt4all-backend/llmodel.h"
@@ -16,7 +18,7 @@ enum LLModelType {
 };
 
 struct LLModelInfo {
-    LLModel *model = nullptr;
+    std::unique_ptr<LLModel> model;
     QFileInfo fileInfo;
     // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
     // must be able to serialize the information even if it is in the unloaded state
@@ -72,6 +74,7 @@ public:
     virtual ~ChatLLM();
 
     void destroy();
+    static void destroyStore();
     bool isModelLoaded() const;
     void regenerateResponse();
     void resetResponse();
@@ -81,7 +84,7 @@ public:
     bool shouldBeLoaded() const { return m_shouldBeLoaded; }
     void setShouldBeLoaded(bool b);
-    void setShouldTrySwitchContext(bool b);
+    void requestTrySwitchContext();
     void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
 
@@ -101,7 +104,7 @@ public:
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt);
     bool loadDefaultModel();
-    bool trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
+    void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
     bool loadModel(const ModelInfo &modelInfo);
     void modelChangeRequested(const ModelInfo &modelInfo);
     void unloadModel();
@@ -109,7 +112,6 @@ public Q_SLOTS:
     void generateName();
     void handleChatIdChanged(const QString &id);
     void handleShouldBeLoadedChanged();
-    void handleShouldTrySwitchContextChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -128,8 +130,8 @@ Q_SIGNALS:
     void stateChanged();
     void threadStarted();
     void shouldBeLoadedChanged();
-    void shouldTrySwitchContextChanged();
-    void trySwitchContextOfLoadedModelCompleted(bool);
+    void trySwitchContextRequested(const ModelInfo &modelInfo);
+    void trySwitchContextOfLoadedModelCompleted(int value);
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
     void reportDevice(const QString &device);
@@ -172,7 +174,6 @@ private:
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
     std::atomic<bool> m_shouldBeLoaded;
-    std::atomic<bool> m_shouldTrySwitchContext;
     std::atomic<bool> m_isRecalc;
     std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
@@ -181,6 +182,10 @@ private:
     bool m_reloadingToChangeVariant;
     bool m_processedSystemPrompt;
     bool m_restoreStateFromText;
+    // m_pristineLoadedState is set if saveState is unnecessary, either because:
+    // - an unload was queued during LLModel::restoreState()
+    // - the chat will be restored from text and hasn't been interacted with yet
+    bool m_pristineLoadedState = false;
     QVector<QPair<QString, QString>> m_stateFromText;
 };
 
diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml
index 0d6eca81..d72e651e 100644
--- a/gpt4all-chat/qml/ChatView.qml
+++ b/gpt4all-chat/qml/ChatView.qml
@@ -122,10 +122,6 @@ Rectangle {
         return ModelList.modelInfo(currentChat.modelInfo.id).name;
     }
 
-    property bool isCurrentlyLoading: false
-    property real modelLoadingPercentage: 0.0
-    property bool trySwitchContextInProgress: false
-
     PopupDialog {
         id: errorCompatHardware
         anchors.centerIn: parent
@@ -340,34 +336,18 @@ Rectangle {
                 implicitWidth: 575
                 width: window.width >= 750 ? implicitWidth : implicitWidth - (750 - window.width)
                 enabled: !currentChat.isServer
-                    && !window.trySwitchContextInProgress
-                    && !window.isCurrentlyLoading
+                    && !currentChat.trySwitchContextInProgress
+                    && !currentChat.isCurrentlyLoading
                 model: ModelList.installedModels
                 valueRole: "id"
                 textRole: "name"
 
                 function changeModel(index) {
-                    window.modelLoadingPercentage = 0.0;
-                    window.isCurrentlyLoading = true;
                     currentChat.stopGenerating()
                     currentChat.reset();
                     currentChat.modelInfo = ModelList.modelInfo(comboBox.valueAt(index))
                 }
 
-                Connections {
-                    target: currentChat
-                    function onModelLoadingPercentageChanged() {
-                        window.modelLoadingPercentage = currentChat.modelLoadingPercentage;
-                        window.isCurrentlyLoading = currentChat.modelLoadingPercentage !== 0.0
-                            && currentChat.modelLoadingPercentage !== 1.0;
-                    }
-                    function onTrySwitchContextOfLoadedModelAttempted() {
-                        window.trySwitchContextInProgress = true;
-                    }
-                    function onTrySwitchContextOfLoadedModelCompleted() {
-                        window.trySwitchContextInProgress = false;
-                    }
-                }
                 Connections {
                     target: switchModelDialog
                     function onAccepted() {
@@ -377,14 +357,14 @@ Rectangle {
 
                 background: ProgressBar {
                     id: modelProgress
-                    value: window.modelLoadingPercentage
+                    value: currentChat.modelLoadingPercentage
                     background: Rectangle {
                         color: theme.mainComboBackground
                         radius: 10
                     }
                     contentItem: Item {
                         Rectangle {
-                            visible: window.isCurrentlyLoading
+                            visible: currentChat.isCurrentlyLoading
                             anchors.bottom: parent.bottom
                             width: modelProgress.visualPosition * parent.width
                             height: 10
@@ -406,13 +386,15 @@ Rectangle {
                     text: {
                         if (currentChat.modelLoadingError !== "")
                             return qsTr("Model loading error...")
-                        if (window.trySwitchContextInProgress)
+                        if (currentChat.trySwitchContextInProgress == 1)
+                            return qsTr("Waiting for model...")
+                        if (currentChat.trySwitchContextInProgress == 2)
                             return qsTr("Switching context...")
                         if (currentModelName() === "")
                             return qsTr("Choose a model...")
                         if (currentChat.modelLoadingPercentage === 0.0)
                             return qsTr("Reload \u00B7 ") + currentModelName()
-                        if (window.isCurrentlyLoading)
+                        if (currentChat.isCurrentlyLoading)
                             return qsTr("Loading \u00B7 ") + currentModelName()
                         return currentModelName()
                     }
@@ -456,7 +438,7 @@ Rectangle {
 
                 MyMiniButton {
                     id: ejectButton
-                    visible: currentChat.isModelLoaded && !window.isCurrentlyLoading
+                    visible: currentChat.isModelLoaded && !currentChat.isCurrentlyLoading
                     z: 500
                     anchors.right: parent.right
                     anchors.rightMargin: 50
@@ -474,8 +456,8 @@ Rectangle {
                 MyMiniButton {
                     id: reloadButton
                     visible: currentChat.modelLoadingError === ""
-                        && !window.trySwitchContextInProgress
-                        && !window.isCurrentlyLoading
+                        && !currentChat.trySwitchContextInProgress
+                        && !currentChat.isCurrentlyLoading
                         && (currentChat.isModelLoaded || currentModelName() !== "")
                     z: 500
                     anchors.right: ejectButton.visible ? ejectButton.left : parent.right
@@ -1344,8 +1326,9 @@ Rectangle {
             textColor: theme.textColor
             visible: !currentChat.isServer
                 && !currentChat.isModelLoaded
-                && !window.trySwitchContextInProgress
-                && !window.isCurrentlyLoading
+                && currentChat.modelLoadingError === ""
+                && !currentChat.trySwitchContextInProgress
+                && !currentChat.isCurrentlyLoading
                 && currentModelName() !== ""
 
             Image {