rename ChatLLM to LlamaCppModel

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
This commit is contained in:
Jared Van Bortel 2024-08-08 15:14:58 -04:00
parent 429613ac32
commit f2e5c931fe
10 changed files with 105 additions and 105 deletions

View File

@ -109,7 +109,7 @@ endif()
qt_add_executable(chat qt_add_executable(chat
main.cpp main.cpp
chat.h chat.cpp chat.h chat.cpp
chatllm.h chatllm.cpp llamacpp_model.h llamacpp_model.cpp
chatmodel.h chatlistmodel.h chatlistmodel.cpp chatmodel.h chatlistmodel.h chatlistmodel.cpp
chatapi.h chatapi.cpp chatapi.h chatapi.cpp
chatviewtextprocessor.h chatviewtextprocessor.cpp chatviewtextprocessor.h chatviewtextprocessor.cpp

View File

@ -26,7 +26,7 @@ Chat::Chat(QObject *parent)
, m_chatModel(new ChatModel(this)) , m_chatModel(new ChatModel(this))
, m_responseState(Chat::ResponseStopped) , m_responseState(Chat::ResponseStopped)
, m_creationDate(QDateTime::currentSecsSinceEpoch()) , m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new ChatLLM(this)) , m_llmodel(new LlamaCppModel(this))
, m_collectionModel(new LocalDocsCollectionsModel(this)) , m_collectionModel(new LocalDocsCollectionsModel(this))
{ {
connectLLM(); connectLLM();
@ -55,31 +55,31 @@ Chat::~Chat()
void Chat::connectLLM() void Chat::connectLLM()
{ {
// Should be in different threads // Should be in different threads
connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::restoringFromTextChanged, this, &Chat::handleRestoringFromText, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection); connect(m_llmodel, &LlamaCppModel::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);
connect(this, &Chat::promptRequested, m_llmodel, &ChatLLM::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &LlamaCppModel::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &ChatLLM::modelChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelChangeRequested, m_llmodel, &LlamaCppModel::modelChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection); connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LlamaCppModel::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection); connect(this, &Chat::loadModelRequested, m_llmodel, &LlamaCppModel::loadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &LlamaCppModel::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &ChatLLM::regenerateResponse, Qt::QueuedConnection); connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LlamaCppModel::regenerateResponse, Qt::QueuedConnection);
connect(this, &Chat::resetResponseRequested, m_llmodel, &ChatLLM::resetResponse, Qt::QueuedConnection); connect(this, &Chat::resetResponseRequested, m_llmodel, &LlamaCppModel::resetResponse, Qt::QueuedConnection);
connect(this, &Chat::resetContextRequested, m_llmodel, &ChatLLM::resetContext, Qt::QueuedConnection); connect(this, &Chat::resetContextRequested, m_llmodel, &LlamaCppModel::resetContext, Qt::QueuedConnection);
connect(this, &Chat::processSystemPromptRequested, m_llmodel, &ChatLLM::processSystemPrompt, Qt::QueuedConnection); connect(this, &Chat::processSystemPromptRequested, m_llmodel, &LlamaCppModel::processSystemPrompt, Qt::QueuedConnection);
connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections); connect(this, &Chat::collectionListChanged, m_collectionModel, &LocalDocsCollectionsModel::setCollections);
} }

View File

@ -1,7 +1,7 @@
#ifndef CHAT_H #ifndef CHAT_H
#define CHAT_H #define CHAT_H
#include "chatllm.h" #include "llamacpp_model.h"
#include "chatmodel.h" #include "chatmodel.h"
#include "database.h" // IWYU pragma: keep #include "database.h" // IWYU pragma: keep
#include "localdocsmodel.h" // IWYU pragma: keep #include "localdocsmodel.h" // IWYU pragma: keep
@ -191,7 +191,7 @@ private:
bool m_responseInProgress = false; bool m_responseInProgress = false;
ResponseState m_responseState; ResponseState m_responseState;
qint64 m_creationDate; qint64 m_creationDate;
ChatLLM *m_llmodel; LlamaCppModel *m_llmodel;
QList<ResultInfo> m_databaseResults; QList<ResultInfo> m_databaseResults;
bool m_isServer = false; bool m_isServer = false;
bool m_shouldDeleteLater = false; bool m_shouldDeleteLater = false;

View File

@ -47,7 +47,7 @@ bool ChatAPI::isModelLoaded() const
return true; return true;
} }
// All three of the state virtual functions are handled custom inside of chatllm save/restore // All three of the state virtual functions are handled custom inside of LlamaCppModel save/restore
size_t ChatAPI::stateSize() const size_t ChatAPI::stateSize() const
{ {
return 0; return 0;

View File

@ -2,7 +2,7 @@
#define CHATLISTMODEL_H #define CHATLISTMODEL_H
#include "chat.h" #include "chat.h"
#include "chatllm.h" #include "llamacpp_model.h"
#include "chatmodel.h" #include "chatmodel.h"
#include <QAbstractListModel> #include <QAbstractListModel>
@ -220,11 +220,11 @@ public:
int count() const { return m_chats.size(); } int count() const { return m_chats.size(); }
// stop ChatLLM threads for clean shutdown // stop LlamaCppModel threads for clean shutdown
void destroyChats() void destroyChats()
{ {
for (auto *chat: m_chats) { chat->destroy(); } for (auto *chat: m_chats) { chat->destroy(); }
ChatLLM::destroyStore(); LlamaCppModel::destroyStore();
} }
void removeChatFile(Chat *chat) const; void removeChatFile(Chat *chat) const;

View File

@ -1,4 +1,4 @@
#include "chatllm.h" #include "llamacpp_model.h"
#include "chat.h" #include "chat.h"
#include "chatapi.h" #include "chatapi.h"
@ -94,13 +94,13 @@ void LLModelStore::destroy()
m_availableModel.reset(); m_availableModel.reset();
} }
void LLModelInfo::resetModel(ChatLLM *cllm, ModelBackend *model) { void LLModelInfo::resetModel(LlamaCppModel *cllm, ModelBackend *model) {
this->model.reset(model); this->model.reset(model);
fallbackReason.reset(); fallbackReason.reset();
emit cllm->loadedModelInfoChanged(); emit cllm->loadedModelInfoChanged();
} }
ChatLLM::ChatLLM(Chat *parent, bool isServer) LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
: QObject{nullptr} : QObject{nullptr}
, m_promptResponseTokens(0) , m_promptResponseTokens(0)
, m_promptTokens(0) , m_promptTokens(0)
@ -117,29 +117,29 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_restoreStateFromText(false) , m_restoreStateFromText(false)
{ {
moveToThread(&m_llmThread); moveToThread(&m_llmThread);
connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(this, &ChatLLM::trySwitchContextRequested, this, &ChatLLM::trySwitchContextOfLoadedModel, connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
Qt::QueuedConnection); // explicitly queued Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged); connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted); connect(&m_llmThread, &QThread::started, this, &LlamaCppModel::handleThreadStarted);
connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged); connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &LlamaCppModel::handleForceMetalChanged);
connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &ChatLLM::handleDeviceChanged); connect(MySettings::globalInstance(), &MySettings::deviceChanged, this, &LlamaCppModel::handleDeviceChanged);
// The following are blocking operations and will block the llm thread // The following are blocking operations and will block the llm thread
connect(this, &ChatLLM::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB, connect(this, &LlamaCppModel::requestRetrieveFromDB, LocalDocs::globalInstance()->database(), &Database::retrieveFromDB,
Qt::BlockingQueuedConnection); Qt::BlockingQueuedConnection);
m_llmThread.setObjectName(parent->id()); m_llmThread.setObjectName(parent->id());
m_llmThread.start(); m_llmThread.start();
} }
ChatLLM::~ChatLLM() LlamaCppModel::~LlamaCppModel()
{ {
destroy(); destroy();
} }
void ChatLLM::destroy() void LlamaCppModel::destroy()
{ {
m_stopGenerating = true; m_stopGenerating = true;
m_llmThread.quit(); m_llmThread.quit();
@ -152,19 +152,19 @@ void ChatLLM::destroy()
} }
} }
void ChatLLM::destroyStore() void LlamaCppModel::destroyStore()
{ {
LLModelStore::globalInstance()->destroy(); LLModelStore::globalInstance()->destroy();
} }
void ChatLLM::handleThreadStarted() void LlamaCppModel::handleThreadStarted()
{ {
m_timer = new TokenTimer(this); m_timer = new TokenTimer(this);
connect(m_timer, &TokenTimer::report, this, &ChatLLM::reportSpeed); connect(m_timer, &TokenTimer::report, this, &LlamaCppModel::reportSpeed);
emit threadStarted(); emit threadStarted();
} }
void ChatLLM::handleForceMetalChanged(bool forceMetal) void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
{ {
#if defined(Q_OS_MAC) && defined(__aarch64__) #if defined(Q_OS_MAC) && defined(__aarch64__)
m_forceMetal = forceMetal; m_forceMetal = forceMetal;
@ -177,7 +177,7 @@ void ChatLLM::handleForceMetalChanged(bool forceMetal)
#endif #endif
} }
void ChatLLM::handleDeviceChanged() void LlamaCppModel::handleDeviceChanged()
{ {
if (isModelLoaded() && m_shouldBeLoaded) { if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true; m_reloadingToChangeVariant = true;
@ -187,7 +187,7 @@ void ChatLLM::handleDeviceChanged()
} }
} }
bool ChatLLM::loadDefaultModel() bool LlamaCppModel::loadDefaultModel()
{ {
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo(); ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
if (defaultModel.filename().isEmpty()) { if (defaultModel.filename().isEmpty()) {
@ -197,7 +197,7 @@ bool ChatLLM::loadDefaultModel()
return loadModel(defaultModel); return loadModel(defaultModel);
} }
void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo) void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{ {
// We're trying to see if the store already has the model fully loaded that we wish to use // We're trying to see if the store already has the model fully loaded that we wish to use
// and if so we just acquire it from the store and switch the context and return true. If the // and if so we just acquire it from the store and switch the context and return true. If the
@ -241,7 +241,7 @@ void ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
processSystemPrompt(); processSystemPrompt();
} }
bool ChatLLM::loadModel(const ModelInfo &modelInfo) bool LlamaCppModel::loadModel(const ModelInfo &modelInfo)
{ {
// This is a complicated method because N different possible threads are interested in the outcome // This is a complicated method because N different possible threads are interested in the outcome
// of this method. Why? Because we have a main/gui thread trying to monitor the state of N different // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
@ -388,7 +388,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
/* Returns false if the model should no longer be loaded (!m_shouldBeLoaded). /* Returns false if the model should no longer be loaded (!m_shouldBeLoaded).
* Otherwise returns true, even on error. */ * Otherwise returns true, even on error. */
bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps) bool LlamaCppModel::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps)
{ {
QElapsedTimer modelLoadTimer; QElapsedTimer modelLoadTimer;
modelLoadTimer.start(); modelLoadTimer.start();
@ -585,7 +585,7 @@ bool ChatLLM::loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadPro
return true; return true;
}; };
bool ChatLLM::isModelLoaded() const bool LlamaCppModel::isModelLoaded() const
{ {
return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded(); return m_llModelInfo.model && m_llModelInfo.model->isModelLoaded();
} }
@ -619,7 +619,7 @@ std::string trim_whitespace(const std::string& input)
} }
// FIXME(jared): we don't actually have to re-decode the prompt to generate a new response // FIXME(jared): we don't actually have to re-decode the prompt to generate a new response
void ChatLLM::regenerateResponse() void LlamaCppModel::regenerateResponse()
{ {
// ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning // ChatGPT uses a different semantic meaning for n_past than local models. For ChatGPT, the meaning
// of n_past is of the number of prompt/response pairs, rather than for total tokens. // of n_past is of the number of prompt/response pairs, rather than for total tokens.
@ -635,7 +635,7 @@ void ChatLLM::regenerateResponse()
emit responseChanged(QString::fromStdString(m_response)); emit responseChanged(QString::fromStdString(m_response));
} }
void ChatLLM::resetResponse() void LlamaCppModel::resetResponse()
{ {
m_promptTokens = 0; m_promptTokens = 0;
m_promptResponseTokens = 0; m_promptResponseTokens = 0;
@ -643,43 +643,43 @@ void ChatLLM::resetResponse()
emit responseChanged(QString::fromStdString(m_response)); emit responseChanged(QString::fromStdString(m_response));
} }
void ChatLLM::resetContext() void LlamaCppModel::resetContext()
{ {
resetResponse(); resetResponse();
m_processedSystemPrompt = false; m_processedSystemPrompt = false;
m_ctx = ModelBackend::PromptContext(); m_ctx = ModelBackend::PromptContext();
} }
QString ChatLLM::response() const QString LlamaCppModel::response() const
{ {
return QString::fromStdString(remove_leading_whitespace(m_response)); return QString::fromStdString(remove_leading_whitespace(m_response));
} }
void ChatLLM::setModelInfo(const ModelInfo &modelInfo) void LlamaCppModel::setModelInfo(const ModelInfo &modelInfo)
{ {
m_modelInfo = modelInfo; m_modelInfo = modelInfo;
emit modelInfoChanged(modelInfo); emit modelInfoChanged(modelInfo);
} }
void ChatLLM::acquireModel() void LlamaCppModel::acquireModel()
{ {
m_llModelInfo = LLModelStore::globalInstance()->acquireModel(); m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
emit loadedModelInfoChanged(); emit loadedModelInfoChanged();
} }
void ChatLLM::resetModel() void LlamaCppModel::resetModel()
{ {
m_llModelInfo = {}; m_llModelInfo = {};
emit loadedModelInfoChanged(); emit loadedModelInfoChanged();
} }
void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo) void LlamaCppModel::modelChangeRequested(const ModelInfo &modelInfo)
{ {
m_shouldBeLoaded = true; m_shouldBeLoaded = true;
loadModel(modelInfo); loadModel(modelInfo);
} }
bool ChatLLM::handlePrompt(int32_t token) bool LlamaCppModel::handlePrompt(int32_t token)
{ {
// m_promptResponseTokens is related to last prompt/response not // m_promptResponseTokens is related to last prompt/response not
// the entire context window which we can reset on regenerate prompt // the entire context window which we can reset on regenerate prompt
@ -692,7 +692,7 @@ bool ChatLLM::handlePrompt(int32_t token)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool ChatLLM::handleResponse(int32_t token, const std::string &response) bool LlamaCppModel::handleResponse(int32_t token, const std::string &response)
{ {
#if defined(DEBUG) #if defined(DEBUG)
printf("%s", response.c_str()); printf("%s", response.c_str());
@ -716,7 +716,7 @@ bool ChatLLM::handleResponse(int32_t token, const std::string &response)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt) bool LlamaCppModel::prompt(const QList<QString> &collectionList, const QString &prompt)
{ {
if (m_restoreStateFromText) { if (m_restoreStateFromText) {
Q_ASSERT(m_state.isEmpty()); Q_ASSERT(m_state.isEmpty());
@ -738,7 +738,7 @@ bool ChatLLM::prompt(const QList<QString> &collectionList, const QString &prompt
repeat_penalty, repeat_penalty_tokens); repeat_penalty, repeat_penalty_tokens);
} }
bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate, bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const QString &prompt, const QString &promptTemplate,
int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty, int32_t n_predict, int32_t top_k, float top_p, float min_p, float temp, int32_t n_batch, float repeat_penalty,
int32_t repeat_penalty_tokens) int32_t repeat_penalty_tokens)
{ {
@ -766,8 +766,8 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
int n_threads = MySettings::globalInstance()->threadCount(); int n_threads = MySettings::globalInstance()->threadCount();
m_stopGenerating = false; m_stopGenerating = false;
auto promptFunc = std::bind(&ChatLLM::handlePrompt, this, std::placeholders::_1); auto promptFunc = std::bind(&LlamaCppModel::handlePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleResponse, this, std::placeholders::_1, auto responseFunc = std::bind(&LlamaCppModel::handleResponse, this, std::placeholders::_1,
std::placeholders::_2); std::placeholders::_2);
emit promptProcessing(); emit promptProcessing();
m_ctx.n_predict = n_predict; m_ctx.n_predict = n_predict;
@ -820,7 +820,7 @@ bool ChatLLM::promptInternal(const QList<QString> &collectionList, const QString
return true; return true;
} }
void ChatLLM::setShouldBeLoaded(bool b) void LlamaCppModel::setShouldBeLoaded(bool b)
{ {
#if defined(DEBUG_MODEL_LOADING) #if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get(); qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
@ -829,13 +829,13 @@ void ChatLLM::setShouldBeLoaded(bool b)
emit shouldBeLoadedChanged(); emit shouldBeLoadedChanged();
} }
void ChatLLM::requestTrySwitchContext() void LlamaCppModel::requestTrySwitchContext()
{ {
m_shouldBeLoaded = true; // atomic m_shouldBeLoaded = true; // atomic
emit trySwitchContextRequested(modelInfo()); emit trySwitchContextRequested(modelInfo());
} }
void ChatLLM::handleShouldBeLoadedChanged() void LlamaCppModel::handleShouldBeLoadedChanged()
{ {
if (m_shouldBeLoaded) if (m_shouldBeLoaded)
reloadModel(); reloadModel();
@ -843,7 +843,7 @@ void ChatLLM::handleShouldBeLoadedChanged()
unloadModel(); unloadModel();
} }
void ChatLLM::unloadModel() void LlamaCppModel::unloadModel()
{ {
if (!isModelLoaded() || m_isServer) if (!isModelLoaded() || m_isServer)
return; return;
@ -869,7 +869,7 @@ void ChatLLM::unloadModel()
m_pristineLoadedState = false; m_pristineLoadedState = false;
} }
void ChatLLM::reloadModel() void LlamaCppModel::reloadModel()
{ {
if (isModelLoaded() && m_forceUnloadModel) if (isModelLoaded() && m_forceUnloadModel)
unloadModel(); // we unload first if we are forcing an unload unloadModel(); // we unload first if we are forcing an unload
@ -887,7 +887,7 @@ void ChatLLM::reloadModel()
loadModel(m); loadModel(m);
} }
void ChatLLM::generateName() void LlamaCppModel::generateName()
{ {
Q_ASSERT(isModelLoaded()); Q_ASSERT(isModelLoaded());
if (!isModelLoaded()) if (!isModelLoaded())
@ -895,13 +895,13 @@ void ChatLLM::generateName()
const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo); const QString chatNamePrompt = MySettings::globalInstance()->modelChatNamePrompt(m_modelInfo);
if (chatNamePrompt.trimmed().isEmpty()) { if (chatNamePrompt.trimmed().isEmpty()) {
qWarning() << "ChatLLM: not generating chat name because prompt is empty"; qWarning() << "LlamaCppModel: not generating chat name because prompt is empty";
return; return;
} }
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleNamePrompt, this, std::placeholders::_1); auto promptFunc = std::bind(&LlamaCppModel::handleNamePrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2); auto responseFunc = std::bind(&LlamaCppModel::handleNameResponse, this, std::placeholders::_1, std::placeholders::_2);
ModelBackend::PromptContext ctx = m_ctx; ModelBackend::PromptContext ctx = m_ctx;
m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(), m_llModelInfo.model->prompt(chatNamePrompt.toStdString(), promptTemplate.toStdString(),
promptFunc, responseFunc, /*allowContextShift*/ false, ctx); promptFunc, responseFunc, /*allowContextShift*/ false, ctx);
@ -913,12 +913,12 @@ void ChatLLM::generateName()
m_pristineLoadedState = false; m_pristineLoadedState = false;
} }
void ChatLLM::handleChatIdChanged(const QString &id) void LlamaCppModel::handleChatIdChanged(const QString &id)
{ {
m_llmThread.setObjectName(id); m_llmThread.setObjectName(id);
} }
bool ChatLLM::handleNamePrompt(int32_t token) bool LlamaCppModel::handleNamePrompt(int32_t token)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "name prompt" << m_llmThread.objectName() << token; qDebug() << "name prompt" << m_llmThread.objectName() << token;
@ -927,7 +927,7 @@ bool ChatLLM::handleNamePrompt(int32_t token)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool ChatLLM::handleNameResponse(int32_t token, const std::string &response) bool LlamaCppModel::handleNameResponse(int32_t token, const std::string &response)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "name response" << m_llmThread.objectName() << token << response; qDebug() << "name response" << m_llmThread.objectName() << token << response;
@ -941,7 +941,7 @@ bool ChatLLM::handleNameResponse(int32_t token, const std::string &response)
return words.size() <= 3; return words.size() <= 3;
} }
bool ChatLLM::handleQuestionPrompt(int32_t token) bool LlamaCppModel::handleQuestionPrompt(int32_t token)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "question prompt" << m_llmThread.objectName() << token; qDebug() << "question prompt" << m_llmThread.objectName() << token;
@ -950,7 +950,7 @@ bool ChatLLM::handleQuestionPrompt(int32_t token)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response) bool LlamaCppModel::handleQuestionResponse(int32_t token, const std::string &response)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "question response" << m_llmThread.objectName() << token << response; qDebug() << "question response" << m_llmThread.objectName() << token << response;
@ -979,7 +979,7 @@ bool ChatLLM::handleQuestionResponse(int32_t token, const std::string &response)
return true; return true;
} }
void ChatLLM::generateQuestions(qint64 elapsed) void LlamaCppModel::generateQuestions(qint64 elapsed)
{ {
Q_ASSERT(isModelLoaded()); Q_ASSERT(isModelLoaded());
if (!isModelLoaded()) { if (!isModelLoaded()) {
@ -996,8 +996,8 @@ void ChatLLM::generateQuestions(qint64 elapsed)
emit generatingQuestions(); emit generatingQuestions();
m_questionResponse.clear(); m_questionResponse.clear();
auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); auto promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
auto promptFunc = std::bind(&ChatLLM::handleQuestionPrompt, this, std::placeholders::_1); auto promptFunc = std::bind(&LlamaCppModel::handleQuestionPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2); auto responseFunc = std::bind(&LlamaCppModel::handleQuestionResponse, this, std::placeholders::_1, std::placeholders::_2);
ModelBackend::PromptContext ctx = m_ctx; ModelBackend::PromptContext ctx = m_ctx;
QElapsedTimer totalTime; QElapsedTimer totalTime;
totalTime.start(); totalTime.start();
@ -1008,7 +1008,7 @@ void ChatLLM::generateQuestions(qint64 elapsed)
} }
bool ChatLLM::handleSystemPrompt(int32_t token) bool LlamaCppModel::handleSystemPrompt(int32_t token)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating; qDebug() << "system prompt" << m_llmThread.objectName() << token << m_stopGenerating;
@ -1017,7 +1017,7 @@ bool ChatLLM::handleSystemPrompt(int32_t token)
return !m_stopGenerating; return !m_stopGenerating;
} }
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) bool LlamaCppModel::handleRestoreStateFromTextPrompt(int32_t token)
{ {
#if defined(DEBUG) #if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating; qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
@ -1028,7 +1028,7 @@ bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
// this function serialized the cached model state to disk. // this function serialized the cached model state to disk.
// we want to also serialize n_ctx, and read it at load time. // we want to also serialize n_ctx, and read it at load time.
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) bool LlamaCppModel::serialize(QDataStream &stream, int version, bool serializeKV)
{ {
if (version > 1) { if (version > 1) {
stream << m_llModelType; stream << m_llModelType;
@ -1068,7 +1068,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
return stream.status() == QDataStream::Ok; return stream.status() == QDataStream::Ok;
} }
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) bool LlamaCppModel::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{ {
if (version > 1) { if (version > 1) {
int internalStateVersion; int internalStateVersion;
@ -1148,7 +1148,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV,
return stream.status() == QDataStream::Ok; return stream.status() == QDataStream::Ok;
} }
void ChatLLM::saveState() void LlamaCppModel::saveState()
{ {
if (!isModelLoaded() || m_pristineLoadedState) if (!isModelLoaded() || m_pristineLoadedState)
return; return;
@ -1170,7 +1170,7 @@ void ChatLLM::saveState()
m_llModelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data()))); m_llModelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
} }
void ChatLLM::restoreState() void LlamaCppModel::restoreState()
{ {
if (!isModelLoaded()) if (!isModelLoaded())
return; return;
@ -1211,7 +1211,7 @@ void ChatLLM::restoreState()
} }
} }
void ChatLLM::processSystemPrompt() void LlamaCppModel::processSystemPrompt()
{ {
Q_ASSERT(isModelLoaded()); Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer) if (!isModelLoaded() || m_processedSystemPrompt || m_restoreStateFromText || m_isServer)
@ -1227,7 +1227,7 @@ void ChatLLM::processSystemPrompt()
m_stopGenerating = false; m_stopGenerating = false;
m_ctx = ModelBackend::PromptContext(); m_ctx = ModelBackend::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1); auto promptFunc = std::bind(&LlamaCppModel::handleSystemPrompt, this, std::placeholders::_1);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
@ -1268,7 +1268,7 @@ void ChatLLM::processSystemPrompt()
m_pristineLoadedState = false; m_pristineLoadedState = false;
} }
void ChatLLM::processRestoreStateFromText() void LlamaCppModel::processRestoreStateFromText()
{ {
Q_ASSERT(isModelLoaded()); Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer) if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
@ -1280,7 +1280,7 @@ void ChatLLM::processRestoreStateFromText()
m_stopGenerating = false; m_stopGenerating = false;
m_ctx = ModelBackend::PromptContext(); m_ctx = ModelBackend::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); auto promptFunc = std::bind(&LlamaCppModel::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);

View File

@ -27,7 +27,7 @@
using namespace Qt::Literals::StringLiterals; using namespace Qt::Literals::StringLiterals;
class Chat; class Chat;
class ChatLLM; class LlamaCppModel;
class QDataStream; class QDataStream;
// NOTE: values serialized to disk, do not change or reuse // NOTE: values serialized to disk, do not change or reuse
@ -43,10 +43,10 @@ struct LLModelInfo {
QFileInfo fileInfo; QFileInfo fileInfo;
std::optional<QString> fallbackReason; std::optional<QString> fallbackReason;
// NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which // NOTE: This does not store the model type or name on purpose as this is left for LlamaCppModel which
// must be able to serialize the information even if it is in the unloaded state // must be able to serialize the information even if it is in the unloaded state
void resetModel(ChatLLM *cllm, ModelBackend *model = nullptr); void resetModel(LlamaCppModel *cllm, ModelBackend *model = nullptr);
}; };
class TokenTimer : public QObject { class TokenTimer : public QObject {
@ -89,7 +89,7 @@ private:
quint32 m_tokens; quint32 m_tokens;
}; };
class ChatLLM : public QObject class LlamaCppModel : public QObject
{ {
Q_OBJECT Q_OBJECT
Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged) Q_PROPERTY(bool restoringFromText READ restoringFromText NOTIFY restoringFromTextChanged)
@ -98,8 +98,8 @@ class ChatLLM : public QObject
Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged) Q_PROPERTY(QString fallbackReason READ fallbackReason NOTIFY loadedModelInfoChanged)
public: public:
ChatLLM(Chat *parent, bool isServer = false); LlamaCppModel(Chat *parent, bool isServer = false);
virtual ~ChatLLM(); virtual ~LlamaCppModel();
void destroy(); void destroy();
static void destroyStore(); static void destroyStore();

View File

@ -87,7 +87,7 @@ int main(int argc, char *argv[])
int res = app.exec(); int res = app.exec();
// Make sure ChatLLM threads are joined before global destructors run. // Make sure LlamaCppModel threads are joined before global destructors run.
// Otherwise, we can get a heap-use-after-free inside of llama.cpp. // Otherwise, we can get a heap-use-after-free inside of llama.cpp.
ChatListModel::globalInstance()->destroyChats(); ChatListModel::globalInstance()->destroyChats();

View File

@ -71,7 +71,7 @@ static inline QJsonObject resultToJson(const ResultInfo &info)
} }
Server::Server(Chat *chat) Server::Server(Chat *chat)
: ChatLLM(chat, true /*isServer*/) : LlamaCppModel(chat, true /*isServer*/)
, m_chat(chat) , m_chat(chat)
, m_server(nullptr) , m_server(nullptr)
{ {

View File

@ -1,7 +1,7 @@
#ifndef SERVER_H #ifndef SERVER_H
#define SERVER_H #define SERVER_H
#include "chatllm.h" #include "llamacpp_model.h"
#include "database.h" #include "database.h"
#include <QHttpServerRequest> #include <QHttpServerRequest>
@ -13,7 +13,7 @@
class Chat; class Chat;
class QHttpServer; class QHttpServer;
class Server : public ChatLLM class Server : public LlamaCppModel
{ {
Q_OBJECT Q_OBJECT