diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index 92e98d61..4ef701db 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -74,7 +74,6 @@ void Chat::connectLLM()
 
     connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
-    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
     connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
     connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
@@ -277,25 +276,23 @@ void Chat::markForDeletion()
 void Chat::unloadModel()
 {
     stopGenerating();
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync();
 }
 
 void Chat::reloadModel()
 {
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync();
 }
 
 void Chat::forceUnloadModel()
 {
     stopGenerating();
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync(/*unload*/ true);
 }
 
 void Chat::forceReloadModel()
 {
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync(/*reload*/ true);
 }
 
 void Chat::trySwitchContextOfLoadedModel()
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index dd914334..959004e5 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -145,7 +145,6 @@ Q_SIGNALS:
     void modelChangeRequested(const ModelInfo &modelInfo);
     void modelInfoChanged();
     void restoringFromTextChanged();
-    void loadDefaultModelRequested();
     void loadModelRequested(const ModelInfo &modelInfo);
     void generateNameRequested();
     void modelLoadingErrorChanged();
diff --git a/gpt4all-chat/llamacpp_model.cpp b/gpt4all-chat/llamacpp_model.cpp
index d65ad445..b9cacd98 100644
--- a/gpt4all-chat/llamacpp_model.cpp
+++ b/gpt4all-chat/llamacpp_model.cpp
@@ -106,7 +106,6 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_promptTokens(0)
     , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
-    , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
@@ -117,8 +116,10 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
-    connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
-            Qt::QueuedConnection); // explicitly queued
+    connect(
+        this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel
+    );
+    connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel);
     connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
             Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
@@ -170,8 +171,7 @@ void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
     m_forceMetal = forceMetal;
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 #endif
@@ -181,22 +181,11 @@ void LlamaCppModel::handleDeviceChanged()
 {
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 }
 
-bool LlamaCppModel::loadDefaultModel()
-{
-    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
-    if (defaultModel.filename().isEmpty()) {
-        emit modelLoadingError(u"Could not find any model to load"_qs);
-        return false;
-    }
-    return loadModel(defaultModel);
-}
-
 void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
@@ -820,13 +809,16 @@ bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const Q
     return true;
 }
 
-void LlamaCppModel::setShouldBeLoaded(bool b)
+void LlamaCppModel::loadModelAsync(bool reload)
 {
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
-#endif
-    m_shouldBeLoaded = b; // atomic
-    emit shouldBeLoadedChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit requestLoadModel(reload);
+}
+
+void LlamaCppModel::releaseModelAsync(bool unload)
+{
+    m_shouldBeLoaded = false; // atomic
+    emit requestReleaseModel(unload);
 }
 
 void LlamaCppModel::requestTrySwitchContext()
@@ -835,23 +827,43 @@
     emit trySwitchContextRequested(modelInfo());
 }
 
-void LlamaCppModel::handleShouldBeLoadedChanged()
+void LlamaCppModel::loadModel(bool reload)
 {
-    if (m_shouldBeLoaded)
-        reloadModel();
-    else
-        unloadModel();
+    Q_ASSERT(m_shouldBeLoaded);
+    if (m_isServer)
+        return; // server managed models directly
+
+    if (reload)
+        releaseModel(/*unload*/ true);
+    else if (isModelLoaded())
+        return; // already loaded
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
+#endif
+    ModelInfo m = modelInfo();
+    if (m.name().isEmpty()) {
+        ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
+        if (defaultModel.filename().isEmpty()) {
+            emit modelLoadingError(u"Could not find any model to load"_s);
+            return;
+        }
+        m = defaultModel;
+    }
+    loadModel(m);
 }
 
-void LlamaCppModel::unloadModel()
+void LlamaCppModel::releaseModel(bool unload)
 {
     if (!isModelLoaded() || m_isServer)
         return;
 
-    if (!m_forceUnloadModel || !m_shouldBeLoaded)
+    if (unload && m_shouldBeLoaded) {
+        // reloading the model, don't show unloaded status
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small positive value
+    } else {
         emit modelLoadingPercentageChanged(0.0f);
-    else
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    }
 
     if (!m_markedForDeletion)
         saveState();
@@ -860,33 +872,14 @@
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    if (m_forceUnloadModel) {
+    if (unload) {
         m_llModelInfo.resetModel(this);
-        m_forceUnloadModel = false;
     }
 
     LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
     m_pristineLoadedState = false;
 }
 
-void LlamaCppModel::reloadModel()
-{
-    if (isModelLoaded() && m_forceUnloadModel)
-        unloadModel(); // we unload first if we are forcing an unload
-
-    if (isModelLoaded() || m_isServer)
-        return;
-
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
-#endif
-    const ModelInfo m = modelInfo();
-    if (m.name().isEmpty())
-        loadDefaultModel();
-    else
-        loadModel(m);
-}
-
 void LlamaCppModel::generateName()
 {
     Q_ASSERT(isModelLoaded());
diff --git a/gpt4all-chat/llamacpp_model.h b/gpt4all-chat/llamacpp_model.h
index c2bcda61..07a9c91b 100644
--- a/gpt4all-chat/llamacpp_model.h
+++ b/gpt4all-chat/llamacpp_model.h
@@ -109,9 +109,9 @@ public:
 
     void stopGenerating() override { m_stopGenerating = true; }
 
-    void setShouldBeLoaded(bool b) override;
+    void loadModelAsync(bool reload = false) override;
+    void releaseModelAsync(bool unload = false) override;
     void requestTrySwitchContext() override;
-    void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
 
     void setModelInfo(const ModelInfo &info) override;
@@ -147,14 +147,14 @@ public:
 
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
-    bool loadDefaultModel() override;
     bool loadModel(const ModelInfo &modelInfo) override;
     void modelChangeRequested(const ModelInfo &modelInfo) override;
     void generateName() override;
     void processSystemPrompt() override;
 
 Q_SIGNALS:
-    void shouldBeLoadedChanged();
+    void requestLoadModel(bool reload);
+    void requestReleaseModel(bool unload);
 
 protected:
     bool isModelLoaded() const;
@@ -183,11 +183,10 @@ protected:
 
 protected Q_SLOTS:
     void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
+    void loadModel(bool reload = false);
+    void releaseModel(bool unload = false);
     void generateQuestions(qint64 elapsed);
     void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -197,11 +196,13 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
 protected:
-    ModelBackend::PromptContext m_ctx;
+    // used by Server
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
+    std::atomic<bool> m_shouldBeLoaded;
 
 private:
+    ModelBackend::PromptContext m_ctx;
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -212,9 +213,7 @@ private:
    QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
-    std::atomic<bool> m_shouldBeLoaded;
     std::atomic<bool> m_restoringFromText; // status indication
-    std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
     bool m_isServer;
     bool m_forceMetal;
diff --git a/gpt4all-chat/llmodel.h b/gpt4all-chat/llmodel.h
index f3f00ec0..a1bb975b 100644
--- a/gpt4all-chat/llmodel.h
+++ b/gpt4all-chat/llmodel.h
@@ -30,9 +30,9 @@ public:
 
     virtual void stopGenerating() = 0;
 
-    virtual void setShouldBeLoaded(bool b) = 0;
+    virtual void loadModelAsync(bool reload = false) = 0;
+    virtual void releaseModelAsync(bool unload = false) = 0;
     virtual void requestTrySwitchContext() = 0;
-    virtual void setForceUnloadModel(bool b) = 0;
     virtual void setMarkedForDeletion(bool b) = 0;
 
     virtual void setModelInfo(const ModelInfo &info) = 0;
@@ -45,7 +45,6 @@ public:
 
 public Q_SLOTS:
     virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
-    virtual bool loadDefaultModel() = 0;
     virtual bool loadModel(const ModelInfo &modelInfo) = 0;
     virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
     virtual void generateName() = 0;
diff --git a/gpt4all-chat/server.cpp b/gpt4all-chat/server.cpp
index 34cb4018..dd498dbd 100644
--- a/gpt4all-chat/server.cpp
+++ b/gpt4all-chat/server.cpp
@@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     emit requestServerNewPromptResponsePair(actualPrompt); // blocks
 
     // load the new model if necessary
-    setShouldBeLoaded(true);
+    m_shouldBeLoaded = true;
 
     if (modelInfo.filename().isEmpty()) {
        std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
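Usage sketch (illustrative only, not part of the patch): with this change a caller no longer flips state through setShouldBeLoaded()/setForceUnloadModel(); it calls the new async entry points, which set m_shouldBeLoaded and emit requestLoadModel/requestReleaseModel so the actual work runs on the model's own thread. Assuming the LLModel *m_llmodel handle used in chat.cpp above, typical call sites look like:

    m_llmodel->loadModelAsync();                   // load the current (or default) model if not already loaded
    m_llmodel->loadModelAsync(/*reload*/ true);    // force a release-and-reload, e.g. after a device change
    m_llmodel->releaseModelAsync();                // release to the model store, keeping it loaded for reuse
    m_llmodel->releaseModelAsync(/*unload*/ true); // destroy the backend model and free its memory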