From 39f5c5363843f9ed6619996b4c2336e66013911a Mon Sep 17 00:00:00 2001
From: Jared Van Bortel <jared@nomic.ai>
Date: Thu, 8 Aug 2024 15:38:34 -0400
Subject: [PATCH] create a generic interface for LlamaCppModel, called LLModel

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
---
 gpt4all-chat/CMakeLists.txt     |  1 +
 gpt4all-chat/chat.cpp           | 58 +++++++++++++------------
 gpt4all-chat/chat.h             |  4 +-
 gpt4all-chat/llamacpp_model.cpp |  3 +-
 gpt4all-chat/llamacpp_model.h   | 67 ++++++++++------------
 gpt4all-chat/llmodel.h          | 76 +++++++++++++++++++++++++++++++++
 6 files changed, 134 insertions(+), 75 deletions(-)
 create mode 100644 gpt4all-chat/llmodel.h

diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt
index e3c37d6b..3165077d 100644
--- a/gpt4all-chat/CMakeLists.txt
+++ b/gpt4all-chat/CMakeLists.txt
@@ -109,6 +109,7 @@ endif()
 qt_add_executable(chat
     main.cpp
     chat.h chat.cpp
+    llmodel.h
     llamacpp_model.h llamacpp_model.cpp
     chatmodel.h chatlistmodel.h chatlistmodel.cpp
     chatapi.h chatapi.cpp
diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp
index bc4bbb6b..92e98d61 100644
--- a/gpt4all-chat/chat.cpp
+++ b/gpt4all-chat/chat.cpp
@@ -1,6 +1,7 @@
 #include "chat.h"
 
 #include "chatlistmodel.h"
+#include "llamacpp_model.h"
 #include "mysettings.h"
 #include "network.h"
 #include "server.h"
@@ -55,31 +56,31 @@ Chat::~Chat()
 void Chat::connectLLM()
 {
     // Should be in different threads
-    connect(m_llmodel, &LlamaCppModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
-    connect(m_llmodel, &LlamaCppModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &LLModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
 
-    connect(this, &Chat::promptRequested, m_llmodel, &LlamaCppModel::prompt, Qt::QueuedConnection);
-    connect(this, &Chat::modelChangeRequested, m_llmodel, &LlamaCppModel::modelChangeRequested, Qt::QueuedConnection);
-    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LlamaCppModel::loadDefaultModel, Qt::QueuedConnection);
-    connect(this, &Chat::loadModelRequested, m_llmodel, &LlamaCppModel::loadModel, Qt::QueuedConnection);
-    connect(this, &Chat::generateNameRequested, m_llmodel, &LlamaCppModel::generateName, Qt::QueuedConnection);
-    connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LlamaCppModel::regenerateResponse, Qt::QueuedConnection);
-    connect(this, &Chat::resetResponseRequested, m_llmodel, &LlamaCppModel::resetResponse, Qt::QueuedConnection);
-    connect(this, &Chat::resetContextRequested, m_llmodel, &LlamaCppModel::resetContext, Qt::QueuedConnection);
-    connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LlamaCppModel::processSystemPrompt, Qt::QueuedConnection);
+    connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
+    connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
+    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
+    connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
+    connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
+    connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
+    connect(this, &Chat::resetResponseRequested, m_llmodel, &LLModel::resetResponse, Qt::QueuedConnection);
+    connect(this, &Chat::resetContextRequested, m_llmodel, &LLModel::resetContext, Qt::QueuedConnection);
+    connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LLModel::processSystemPrompt, Qt::QueuedConnection);
 
     connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
 }
@@ -344,17 +345,20 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
 
 QString Chat::deviceBackend() const
 {
-    return m_llmodel->deviceBackend();
+    auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
+    return llamacppmodel ? llamacppmodel->deviceBackend() : QString();
 }
 
 QString Chat::device() const
 {
-    return m_llmodel->device();
+    auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
+    return llamacppmodel ? llamacppmodel->device() : QString();
 }
 
 QString Chat::fallbackReason() const
 {
-    return m_llmodel->fallbackReason();
+    auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
+    return llamacppmodel ? llamacppmodel->fallbackReason() : QString();
 }
 
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
diff --git a/gpt4all-chat/chat.h b/gpt4all-chat/chat.h
index f61a26c9..dd914334 100644
--- a/gpt4all-chat/chat.h
+++ b/gpt4all-chat/chat.h
@@ -1,9 +1,9 @@
 #ifndef CHAT_H
 #define CHAT_H
 
-#include "llamacpp_model.h"
 #include "chatmodel.h"
 #include "database.h" // IWYU pragma: keep
+#include "llmodel.h"
 #include "localdocsmodel.h" // IWYU pragma: keep
 #include "modellist.h"
 
@@ -191,7 +191,7 @@ private:
     bool m_responseInProgress = false;
     ResponseState m_responseState;
     qint64 m_creationDate;
-    LlamaCppModel *m_llmodel;
+    LLModel *m_llmodel;
     QList<ResultInfo> m_databaseResults;
     bool m_isServer = false;
     bool m_shouldDeleteLater = false;
diff --git a/gpt4all-chat/llamacpp_model.cpp b/gpt4all-chat/llamacpp_model.cpp
index bca81403..8691717b 100644
--- a/gpt4all-chat/llamacpp_model.cpp
+++ b/gpt4all-chat/llamacpp_model.cpp
@@ -101,8 +101,7 @@ void LLModelInfo::resetModel(LlamaCppModel *cllm, ModelBackend *model) {
 }
 
 LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
-    : QObject{nullptr}
-    , m_promptResponseTokens(0)
+    : m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
diff --git a/gpt4all-chat/llamacpp_model.h b/gpt4all-chat/llamacpp_model.h
index e436044d..c2bcda61 100644
--- a/gpt4all-chat/llamacpp_model.h
+++ b/gpt4all-chat/llamacpp_model.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "database.h" // IWYU pragma: keep
+#include "llmodel.h"
 #include "modellist.h"
 
 #include "../gpt4all-backend/llamacpp_backend.h"
@@ -89,34 +90,33 @@
     quint32 m_tokens;
 };
 
-class LlamaCppModel : public QObject
+class LlamaCppModel : public LLModel
 {
     Q_OBJECT
-    Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
     Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
     Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
 
 public:
     LlamaCppModel(Chat *parent, bool isServer = false);
-    virtual ~LlamaCppModel();
+    ~LlamaCppModel() override;
 
-    void destroy();
+    void destroy() override;
     static void destroyStore();
-    void regenerateResponse();
-    void resetResponse();
-    void resetContext();
+    void regenerateResponse() override;
+    void resetResponse() override;
+    void resetContext() override;
 
-    void stopGenerating() { m_stopGenerating = true; }
+    void stopGenerating() override { m_stopGenerating = true; }
 
-    void setShouldBeLoaded(bool b);
-    void requestTrySwitchContext();
-    void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
-    void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
+    void setShouldBeLoaded(bool b) override;
+    void requestTrySwitchContext() override;
+    void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
+    void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
 
-    void setModelInfo(const ModelInfo &info);
+    void setModelInfo(const ModelInfo &info) override;
 
-    bool restoringFromText() const { return m_restoringFromText; }
+    bool restoringFromText() const override { return m_restoringFromText; }
 
     QString deviceBackend() const
     {
@@ -141,41 +141,20 @@
         return m_llModelInfo.fallbackReason.value_or(u""_s);
     }
 
-    bool serialize(QDataStream &stream, int version, bool serializeKV);
-    bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
-    void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
+    bool serialize(QDataStream &stream, int version, bool serializeKV) override;
+    bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override;
+    void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) override { m_stateFromText = stateFromText; }
 
 public Q_SLOTS:
-    bool prompt(const QList<QString> &collectionList, const QString &prompt);
-    bool loadDefaultModel();
-    bool loadModel(const ModelInfo &modelInfo);
-    void modelChangeRequested(const ModelInfo &modelInfo);
-    void generateName();
-    void processSystemPrompt();
+    bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
+    bool loadDefaultModel() override;
+    bool loadModel(const ModelInfo &modelInfo) override;
+    void modelChangeRequested(const ModelInfo &modelInfo) override;
+    void generateName() override;
+    void processSystemPrompt() override;
 
 Q_SIGNALS:
-    void restoringFromTextChanged();
-    void loadedModelInfoChanged();
-    void modelLoadingPercentageChanged(float);
-    void modelLoadingError(const QString &error);
-    void modelLoadingWarning(const QString &warning);
-    void responseChanged(const QString &response);
-    void promptProcessing();
-    void generatingQuestions();
-    void responseStopped(qint64 promptResponseMs);
-    void generatedNameChanged(const QString &name);
-    void generatedQuestionFinished(const QString &generatedQuestion);
-    void stateChanged();
-    void threadStarted();
     void shouldBeLoadedChanged();
-    void trySwitchContextRequested(const ModelInfo &modelInfo);
-    void trySwitchContextOfLoadedModelCompleted(int value);
-    void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
-    void reportSpeed(const QString &speed);
-    void reportDevice(const QString &device);
-    void reportFallbackReason(const QString &fallbackReason);
-    void databaseResultsChanged(const QList<ResultInfo> &results);
-    void modelInfoChanged(const ModelInfo &modelInfo);
 
 protected:
     bool isModelLoaded() const;
diff --git a/gpt4all-chat/llmodel.h b/gpt4all-chat/llmodel.h
new file mode 100644
index 00000000..f3f00ec0
--- /dev/null
+++ b/gpt4all-chat/llmodel.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include "database.h" // IWYU pragma: keep
+#include "modellist.h" // IWYU pragma: keep
+
+#include <QList>
+#include <QObject>
+#include <QPair>
+#include <QString>
+#include <QVector>
+
+class Chat;
+class QDataStream;
+
+class LLModel : public QObject
+{
+    Q_OBJECT
+    Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
+
+protected:
+    LLModel() = default;
+
+public:
+    virtual ~LLModel() = default;
+
+    virtual void destroy() {}
+    virtual void regenerateResponse() = 0;
+    virtual void resetResponse() = 0;
+    virtual void resetContext() = 0;
+
+    virtual void stopGenerating() = 0;
+
+    virtual void setShouldBeLoaded(bool b) = 0;
+    virtual void requestTrySwitchContext() = 0;
+    virtual void setForceUnloadModel(bool b) = 0;
+    virtual void setMarkedForDeletion(bool b) = 0;
+
+    virtual void setModelInfo(const ModelInfo &info) = 0;
+
+    virtual bool restoringFromText() const = 0;
+
+    virtual bool serialize(QDataStream &stream, int version, bool serializeKV) = 0;
+    virtual bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) = 0;
+    virtual void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) = 0;
+
+public Q_SLOTS:
+    virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
+    virtual bool loadDefaultModel() = 0;
+    virtual bool loadModel(const ModelInfo &modelInfo) = 0;
+    virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
+    virtual void generateName() = 0;
+    virtual void processSystemPrompt() = 0;
+
+Q_SIGNALS:
+    void restoringFromTextChanged();
+    void loadedModelInfoChanged();
+    void modelLoadingPercentageChanged(float loadingPercentage);
+    void modelLoadingError(const QString &error);
+    void modelLoadingWarning(const QString &warning);
+    void responseChanged(const QString &response);
+    void promptProcessing();
+    void generatingQuestions();
+    void responseStopped(qint64 promptResponseMs);
+    void generatedNameChanged(const QString &name);
+    void generatedQuestionFinished(const QString &generatedQuestion);
+    void stateChanged();
+    void threadStarted();
+    void trySwitchContextRequested(const ModelInfo &modelInfo);
+    void trySwitchContextOfLoadedModelCompleted(int value);
+    void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
+    void reportSpeed(const QString &speed);
+    void reportDevice(const QString &device);
+    void reportFallbackReason(const QString &fallbackReason);
+    void databaseResultsChanged(const QList<ResultInfo> &results);
+    void modelInfoChanged(const ModelInfo &modelInfo);
+};
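
Note (illustration only, not part of the patch): with this seam in place, any
backend that implements LLModel's pure virtuals can be driven by Chat's
existing queued connections without touching chat.cpp. Below is a minimal
sketch of a hypothetical subclass; the class name EchoModel and its trivial
behavior are invented for the example and do not exist in the tree.

    // echo_model.h -- hypothetical LLModel subclass (illustration only).
    // It "generates" a response by echoing the prompt back, exercising the
    // same signals Chat already connects to (responseChanged, responseStopped).
    #pragma once

    #include "llmodel.h"

    class EchoModel : public LLModel
    {
        Q_OBJECT

    public:
        EchoModel() = default;

        void regenerateResponse() override {}
        void resetResponse() override { m_response.clear(); }
        void resetContext() override { m_response.clear(); }
        void stopGenerating() override {}
        void setShouldBeLoaded(bool) override {}
        void requestTrySwitchContext() override {}
        void setForceUnloadModel(bool) override {}
        void setMarkedForDeletion(bool) override {}
        void setModelInfo(const ModelInfo &) override {}
        bool restoringFromText() const override { return false; }
        bool serialize(QDataStream &, int, bool) override { return true; }
        bool deserialize(QDataStream &, int, bool, bool) override { return true; }
        void setStateFromText(const QVector<QPair<QString, QString>> &) override {}

    public Q_SLOTS:
        bool prompt(const QList<QString> &, const QString &prompt) override
        {
            m_response = prompt; // trivial "generation": echo the prompt back
            emit responseChanged(m_response);
            emit responseStopped(0 /* promptResponseMs */);
            return true;
        }
        bool loadDefaultModel() override
        {
            emit modelLoadingPercentageChanged(1.0f); // report fully loaded
            return true;
        }
        bool loadModel(const ModelInfo &) override { return true; }
        void modelChangeRequested(const ModelInfo &) override {}
        void generateName() override {}
        void processSystemPrompt() override {}

    private:
        QString m_response;
    };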