create a generic interface for LlamaCppModel, called LLModel

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-08-08 15:38:34 -04:00
parent f2e5c931fe
commit 39f5c53638
6 changed files with 134 additions and 75 deletions

View File

@ -109,6 +109,7 @@ endif()
qt_add_executable(chat
main.cpp
chat.h chat.cpp
llmodel.h
llamacpp_model.h llamacpp_model.cpp
chatmodel.h chatlistmodel.h chatlistmodel.cpp
chatapi.h chatapi.cpp

View File

@ -1,6 +1,7 @@
#include "chat.h"
#include "chatlistmodel.h"
#include "llamacpp_model.h"
#include "mysettings.h"
#include "network.h"
#include "server.h"
@ -55,31 +56,31 @@ Chat::~Chat()
void Chat::connectLLM()
{
// Should be in different threads
connect(m_llmodel, &LlamaCppModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &LlamaCppModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &LLModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &LlamaCppModel::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &LlamaCppModel::modelChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LlamaCppModel::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &LlamaCppModel::loadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &LlamaCppModel::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LlamaCppModel::regenerateResponse, Qt::QueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &LlamaCppModel::resetResponse, Qt::QueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &LlamaCppModel::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LlamaCppModel::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &LLModel::resetResponse, Qt::QueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &LLModel::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LLModel::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
}
@ -344,17 +345,20 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
QString Chat::deviceBackend() const
{
return m_llmodel->deviceBackend();
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
return llamacppmodel ? llamacppmodel->deviceBackend() : QString();
}
QString Chat::device() const
{
return m_llmodel->device();
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
return llamacppmodel ? llamacppmodel->device() : QString();
}
QString Chat::fallbackReason() const
{
return m_llmodel->fallbackReason();
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
return llamacppmodel ? llamacppmodel->fallbackReason() : QString();
}
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)

View File

@ -1,9 +1,9 @@
#ifndef CHAT_H
#define CHAT_H
#include "llamacpp_model.h"
#include "chatmodel.h"
#include "database.h" // IWYU pragma: keep
#include "llmodel.h"
#include "localdocsmodel.h" // IWYU pragma: keep
#include "modellist.h"
@ -191,7 +191,7 @@ private:
bool m_responseInProgress = false;
ResponseState m_responseState;
qint64 m_creationDate;
LlamaCppModel *m_llmodel;
LLModel *m_llmodel;
QList<ResultInfo> m_databaseResults;
bool m_isServer = false;
bool m_shouldDeleteLater = false;

View File

@ -101,8 +101,7 @@ void LLModelInfo::resetModel(LlamaCppModel *cllm, ModelBackend *model) {
}
LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
: QObject{nullptr}
, m_promptResponseTokens(0)
: m_promptResponseTokens(0)
, m_promptTokens(0)
, m_restoringFromText(false)
, m_shouldBeLoaded(false)

View File

@ -1,6 +1,7 @@
#pragma once
#include "database.h" // IWYU pragma: keep
#include "llmodel.h"
#include "modellist.h"
#include "../gpt4all-backend/llamacpp_backend.h"
@ -89,34 +90,33 @@ private:
quint32 m_tokens;
};
class LlamaCppModel : public QObject
class LlamaCppModel : public LLModel
{
Q_OBJECT
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
public:
LlamaCppModel(Chat *parent, bool isServer = false);
virtual ~LlamaCppModel();
~LlamaCppModel() override;
void destroy();
void destroy() override;
static void destroyStore();
void regenerateResponse();
void resetResponse();
void resetContext();
void regenerateResponse() override;
void resetResponse() override;
void resetContext() override;
void stopGenerating() { m_stopGenerating = true; }
void stopGenerating() override { m_stopGenerating = true; }
void setShouldBeLoaded(bool b);
void requestTrySwitchContext();
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
void setShouldBeLoaded(bool b) override;
void requestTrySwitchContext() override;
void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
void setModelInfo(const ModelInfo &info);
void setModelInfo(const ModelInfo &info) override;
bool restoringFromText() const { return m_restoringFromText; }
bool restoringFromText() const override { return m_restoringFromText; }
QString deviceBackend() const
{
@ -141,41 +141,20 @@ public:
return m_llModelInfo.fallbackReason.value_or(u""_s);
}
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
bool serialize(QDataStream &stream, int version, bool serializeKV) override;
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override;
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) override { m_stateFromText = stateFromText; }
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
bool loadDefaultModel();
bool loadModel(const ModelInfo &modelInfo);
void modelChangeRequested(const ModelInfo &modelInfo);
void generateName();
void processSystemPrompt();
bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
bool loadDefaultModel() override;
bool loadModel(const ModelInfo &modelInfo) override;
void modelChangeRequested(const ModelInfo &modelInfo) override;
void generateName() override;
void processSystemPrompt() override;
Q_SIGNALS:
void restoringFromTextChanged();
void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float);
void modelLoadingError(const QString &error);
void modelLoadingWarning(const QString &warning);
void responseChanged(const QString &response);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &generatedQuestion);
void stateChanged();
void threadStarted();
void shouldBeLoadedChanged();
void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
void reportFallbackReason(const QString &fallbackReason);
void databaseResultsChanged(const QList<ResultInfo> &results);
void modelInfoChanged(const ModelInfo &modelInfo);
protected:
bool isModelLoaded() const;

76
gpt4all-chat/llmodel.h Normal file
View File

@ -0,0 +1,76 @@
#pragma once
#include "database.h" // IWYU pragma: keep
#include "modellist.h" // IWYU pragma: keep
#include <QList>
#include <QObject>
#include <QPair>
#include <QString>
#include <QVector>
class Chat;
class QDataStream;
class LLModel : public QObject
{
Q_OBJECT
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
protected:
LLModel() = default;
public:
virtual ~LLModel() = default;
virtual void destroy() {}
virtual void regenerateResponse() = 0;
virtual void resetResponse() = 0;
virtual void resetContext() = 0;
virtual void stopGenerating() = 0;
virtual void setShouldBeLoaded(bool b) = 0;
virtual void requestTrySwitchContext() = 0;
virtual void setForceUnloadModel(bool b) = 0;
virtual void setMarkedForDeletion(bool b) = 0;
virtual void setModelInfo(const ModelInfo &info) = 0;
virtual bool restoringFromText() const = 0;
virtual bool serialize(QDataStream &stream, int version, bool serializeKV) = 0;
virtual bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) = 0;
virtual void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) = 0;
public Q_SLOTS:
virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
virtual bool loadDefaultModel() = 0;
virtual bool loadModel(const ModelInfo &modelInfo) = 0;
virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
virtual void generateName() = 0;
virtual void processSystemPrompt() = 0;
Q_SIGNALS:
void restoringFromTextChanged();
void loadedModelInfoChanged();
void modelLoadingPercentageChanged(float loadingPercentage);
void modelLoadingError(const QString &error);
void modelLoadingWarning(const QString &warning);
void responseChanged(const QString &response);
void promptProcessing();
void generatingQuestions();
void responseStopped(qint64 promptResponseMs);
void generatedNameChanged(const QString &name);
void generatedQuestionFinished(const QString &generatedQuestion);
void stateChanged();
void threadStarted();
void trySwitchContextRequested(const ModelInfo &modelInfo);
void trySwitchContextOfLoadedModelCompleted(int value);
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
void reportSpeed(const QString &speed);
void reportDevice(const QString &device);
void reportFallbackReason(const QString &fallbackReason);
void databaseResultsChanged(const QList<ResultInfo> &results);
void modelInfoChanged(const ModelInfo &modelInfo);
};