mirror of
https://github.com/nomic-ai/gpt4all.git
synced 2025-06-26 23:37:40 +00:00
create a generic interface for LlamaCppModel, called LLModel
Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
parent
f2e5c931fe
commit
39f5c53638
@ -109,6 +109,7 @@ endif()
|
||||
qt_add_executable(chat
|
||||
main.cpp
|
||||
chat.h chat.cpp
|
||||
llmodel.h
|
||||
llamacpp_model.h llamacpp_model.cpp
|
||||
chatmodel.h chatlistmodel.h chatlistmodel.cpp
|
||||
chatapi.h chatapi.cpp
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include "chat.h"
|
||||
|
||||
#include "chatlistmodel.h"
|
||||
#include "llamacpp_model.h"
|
||||
#include "mysettings.h"
|
||||
#include "network.h"
|
||||
#include "server.h"
|
||||
@ -55,31 +56,31 @@ Chat::~Chat()
|
||||
void Chat::connectLLM()
|
||||
{
|
||||
// Should be in different threads
|
||||
connect(m_llmodel, &LlamaCppModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LlamaCppModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
|
||||
connect(m_llmodel, &LLModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
|
||||
|
||||
connect(this, &Chat::promptRequested, m_llmodel, &LlamaCppModel::prompt, Qt::QueuedConnection);
|
||||
connect(this, &Chat::modelChangeRequested, m_llmodel, &LlamaCppModel::modelChangeRequested, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LlamaCppModel::loadDefaultModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadModelRequested, m_llmodel, &LlamaCppModel::loadModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::generateNameRequested, m_llmodel, &LlamaCppModel::generateName, Qt::QueuedConnection);
|
||||
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LlamaCppModel::regenerateResponse, Qt::QueuedConnection);
|
||||
connect(this, &Chat::resetResponseRequested, m_llmodel, &LlamaCppModel::resetResponse, Qt::QueuedConnection);
|
||||
connect(this, &Chat::resetContextRequested, m_llmodel, &LlamaCppModel::resetContext, Qt::QueuedConnection);
|
||||
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LlamaCppModel::processSystemPrompt, Qt::QueuedConnection);
|
||||
connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
|
||||
connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
|
||||
connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
|
||||
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
|
||||
connect(this, &Chat::resetResponseRequested, m_llmodel, &LLModel::resetResponse, Qt::QueuedConnection);
|
||||
connect(this, &Chat::resetContextRequested, m_llmodel, &LLModel::resetContext, Qt::QueuedConnection);
|
||||
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LLModel::processSystemPrompt, Qt::QueuedConnection);
|
||||
|
||||
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
|
||||
}
|
||||
@ -344,17 +345,20 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
|
||||
|
||||
QString Chat::deviceBackend() const
|
||||
{
|
||||
return m_llmodel->deviceBackend();
|
||||
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
|
||||
return llamacppmodel ? llamacppmodel->deviceBackend() : QString();
|
||||
}
|
||||
|
||||
QString Chat::device() const
|
||||
{
|
||||
return m_llmodel->device();
|
||||
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
|
||||
return llamacppmodel ? llamacppmodel->device() : QString();
|
||||
}
|
||||
|
||||
QString Chat::fallbackReason() const
|
||||
{
|
||||
return m_llmodel->fallbackReason();
|
||||
auto *llamacppmodel = dynamic_cast<LlamaCppModel *>(m_llmodel);
|
||||
return llamacppmodel ? llamacppmodel->fallbackReason() : QString();
|
||||
}
|
||||
|
||||
void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
|
||||
|
@ -1,9 +1,9 @@
|
||||
#ifndef CHAT_H
|
||||
#define CHAT_H
|
||||
|
||||
#include "llamacpp_model.h"
|
||||
#include "chatmodel.h"
|
||||
#include "database.h" // IWYU pragma: keep
|
||||
#include "llmodel.h"
|
||||
#include "localdocsmodel.h" // IWYU pragma: keep
|
||||
#include "modellist.h"
|
||||
|
||||
@ -191,7 +191,7 @@ private:
|
||||
bool m_responseInProgress = false;
|
||||
ResponseState m_responseState;
|
||||
qint64 m_creationDate;
|
||||
LlamaCppModel *m_llmodel;
|
||||
LLModel *m_llmodel;
|
||||
QList<ResultInfo> m_databaseResults;
|
||||
bool m_isServer = false;
|
||||
bool m_shouldDeleteLater = false;
|
||||
|
@ -101,8 +101,7 @@ void LLModelInfo::resetModel(LlamaCppModel *cllm, ModelBackend *model) {
|
||||
}
|
||||
|
||||
LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
|
||||
: QObject{nullptr}
|
||||
, m_promptResponseTokens(0)
|
||||
: m_promptResponseTokens(0)
|
||||
, m_promptTokens(0)
|
||||
, m_restoringFromText(false)
|
||||
, m_shouldBeLoaded(false)
|
||||
|
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "database.h" // IWYU pragma: keep
|
||||
#include "llmodel.h"
|
||||
#include "modellist.h"
|
||||
|
||||
#include "../gpt4all-backend/llamacpp_backend.h"
|
||||
@ -89,34 +90,33 @@ private:
|
||||
quint32 m_tokens;
|
||||
};
|
||||
|
||||
class LlamaCppModel : public QObject
|
||||
class LlamaCppModel : public LLModel
|
||||
{
|
||||
Q_OBJECT
|
||||
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
|
||||
Q_PROPERTY(QString deviceBackend READ deviceBackend NOTIFY loadedModelInfoChanged)
|
||||
Q_PROPERTY(QString device READ device NOTIFY loadedModelInfoChanged)
|
||||
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
|
||||
|
||||
public:
|
||||
LlamaCppModel(Chat *parent, bool isServer = false);
|
||||
virtual ~LlamaCppModel();
|
||||
~LlamaCppModel() override;
|
||||
|
||||
void destroy();
|
||||
void destroy() override;
|
||||
static void destroyStore();
|
||||
void regenerateResponse();
|
||||
void resetResponse();
|
||||
void resetContext();
|
||||
void regenerateResponse() override;
|
||||
void resetResponse() override;
|
||||
void resetContext() override;
|
||||
|
||||
void stopGenerating() { m_stopGenerating = true; }
|
||||
void stopGenerating() override { m_stopGenerating = true; }
|
||||
|
||||
void setShouldBeLoaded(bool b);
|
||||
void requestTrySwitchContext();
|
||||
void setForceUnloadModel(bool b) { m_forceUnloadModel = b; }
|
||||
void setMarkedForDeletion(bool b) { m_markedForDeletion = b; }
|
||||
void setShouldBeLoaded(bool b) override;
|
||||
void requestTrySwitchContext() override;
|
||||
void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
|
||||
void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
|
||||
|
||||
void setModelInfo(const ModelInfo &info);
|
||||
void setModelInfo(const ModelInfo &info) override;
|
||||
|
||||
bool restoringFromText() const { return m_restoringFromText; }
|
||||
bool restoringFromText() const override { return m_restoringFromText; }
|
||||
|
||||
QString deviceBackend() const
|
||||
{
|
||||
@ -141,41 +141,20 @@ public:
|
||||
return m_llModelInfo.fallbackReason.value_or(u""_s);
|
||||
}
|
||||
|
||||
bool serialize(QDataStream &stream, int version, bool serializeKV);
|
||||
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
|
||||
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
|
||||
bool serialize(QDataStream &stream, int version, bool serializeKV) override;
|
||||
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) override;
|
||||
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) override { m_stateFromText = stateFromText; }
|
||||
|
||||
public Q_SLOTS:
|
||||
bool prompt(const QList<QString> &collectionList, const QString &prompt);
|
||||
bool loadDefaultModel();
|
||||
bool loadModel(const ModelInfo &modelInfo);
|
||||
void modelChangeRequested(const ModelInfo &modelInfo);
|
||||
void generateName();
|
||||
void processSystemPrompt();
|
||||
bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
|
||||
bool loadDefaultModel() override;
|
||||
bool loadModel(const ModelInfo &modelInfo) override;
|
||||
void modelChangeRequested(const ModelInfo &modelInfo) override;
|
||||
void generateName() override;
|
||||
void processSystemPrompt() override;
|
||||
|
||||
Q_SIGNALS:
|
||||
void restoringFromTextChanged();
|
||||
void loadedModelInfoChanged();
|
||||
void modelLoadingPercentageChanged(float);
|
||||
void modelLoadingError(const QString &error);
|
||||
void modelLoadingWarning(const QString &warning);
|
||||
void responseChanged(const QString &response);
|
||||
void promptProcessing();
|
||||
void generatingQuestions();
|
||||
void responseStopped(qint64 promptResponseMs);
|
||||
void generatedNameChanged(const QString &name);
|
||||
void generatedQuestionFinished(const QString &generatedQuestion);
|
||||
void stateChanged();
|
||||
void threadStarted();
|
||||
void shouldBeLoadedChanged();
|
||||
void trySwitchContextRequested(const ModelInfo &modelInfo);
|
||||
void trySwitchContextOfLoadedModelCompleted(int value);
|
||||
void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
|
||||
void reportSpeed(const QString &speed);
|
||||
void reportDevice(const QString &device);
|
||||
void reportFallbackReason(const QString &fallbackReason);
|
||||
void databaseResultsChanged(const QList<ResultInfo> &results);
|
||||
void modelInfoChanged(const ModelInfo &modelInfo);
|
||||
|
||||
protected:
|
||||
bool isModelLoaded() const;
|
||||
|
76
gpt4all-chat/llmodel.h
Normal file
76
gpt4all-chat/llmodel.h
Normal file
@ -0,0 +1,76 @@
|
||||
#pragma once
|
||||
|
||||
#include "database.h" // IWYU pragma: keep
|
||||
#include "modellist.h" // IWYU pragma: keep
|
||||
|
||||
#include <QList>
|
||||
#include <QObject>
|
||||
#include <QPair>
|
||||
#include <QString>
|
||||
#include <QVector>
|
||||
|
||||
class Chat;
|
||||
class QDataStream;
|
||||
|
||||
// Abstract interface for a chat language-model backend.
//
// LlamaCppModel (local llama.cpp inference) derives from this class and
// implements every pure-virtual member; Chat holds an `LLModel *` and wires
// itself to the signals/slots below with Qt::QueuedConnection, so the
// concrete model is expected to live in a worker thread.
class LLModel : public QObject
{
    Q_OBJECT
    Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)

protected:
    // Interface type: only constructible by subclasses.
    LLModel() = default;

public:
    virtual ~LLModel() = default;

    // Tear down the model; default is a no-op so stateless implementations
    // need not override it.
    virtual void destroy() {}
    virtual void regenerateResponse() = 0;
    virtual void resetResponse() = 0;
    virtual void resetContext() = 0;

    // Cooperative cancellation of an in-flight generation.
    virtual void stopGenerating() = 0;

    virtual void setShouldBeLoaded(bool b) = 0;
    virtual void requestTrySwitchContext() = 0;
    virtual void setForceUnloadModel(bool b) = 0;
    virtual void setMarkedForDeletion(bool b) = 0;

    virtual void setModelInfo(const ModelInfo &info) = 0;

    // Backing store for the restoringFromText Q_PROPERTY.
    virtual bool restoringFromText() const = 0;

    // Chat (de)serialization; `version` selects the on-disk format and the
    // KV flags control whether the model's KV-cache state is included.
    virtual bool serialize(QDataStream &stream, int version, bool serializeKV) = 0;
    virtual bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) = 0;
    virtual void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) = 0;

public Q_SLOTS:
    // Invoked by Chat through queued connections (promptRequested,
    // loadModelRequested, etc.), so these run on the model's own thread.
    virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
    virtual bool loadDefaultModel() = 0;
    virtual bool loadModel(const ModelInfo &modelInfo) = 0;
    virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
    virtual void generateName() = 0;
    virtual void processSystemPrompt() = 0;

Q_SIGNALS:
    // All signals below are consumed by Chat via queued connections.
    void restoringFromTextChanged();
    void loadedModelInfoChanged();
    void modelLoadingPercentageChanged(float loadingPercentage);
    void modelLoadingError(const QString &error);
    void modelLoadingWarning(const QString &warning);
    void responseChanged(const QString &response);
    void promptProcessing();
    void generatingQuestions();
    void responseStopped(qint64 promptResponseMs);
    void generatedNameChanged(const QString &name);
    void generatedQuestionFinished(const QString &generatedQuestion);
    void stateChanged();
    void threadStarted();
    void trySwitchContextRequested(const ModelInfo &modelInfo);
    void trySwitchContextOfLoadedModelCompleted(int value);
    // Asks LocalDocs for retrieval results; `results` is an out-parameter
    // filled by the receiver (presumably synchronously — confirm at call site).
    void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
    void reportSpeed(const QString &speed);
    void reportDevice(const QString &device);
    void reportFallbackReason(const QString &fallbackReason);
    void databaseResultsChanged(const QList<ResultInfo> &results);
    void modelInfoChanged(const ModelInfo &modelInfo);
};
|
Loading…
Reference in New Issue
Block a user