Much better memory management for multi-threaded model loading/unloading.

Adam Treat 2023-05-13 19:05:35 -04:00 committed by AT
parent 2989b74d43
commit ddc24acf33
6 changed files with 243 additions and 74 deletions
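In outline: instead of each chat owning, deleting, and recreating its own LLModel, the single loaded model now lives in a one-slot LLModelStore that chat worker threads acquire and release. ChatLLM records the desired state in an atomic m_shouldBeLoaded flag and performs the actual load/unload on its own thread via a queued signal, so the GUI thread never blocks on model I/O. Chat gains unloadAndDeleteLater() so a removed chat is only destroyed after its worker has handed the model back, and the built-in server chat is excluded from chat saving and removal.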

chat.cpp

@@ -12,6 +12,7 @@ Chat::Chat(QObject *parent)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new ChatLLM(this))
     , m_isServer(false)
+    , m_shouldDeleteLater(false)
 {
     connectLLM();
 }
@@ -25,6 +26,7 @@ Chat::Chat(bool isServer, QObject *parent)
     , m_creationDate(QDateTime::currentSecsSinceEpoch())
     , m_llmodel(new Server(this))
     , m_isServer(true)
+    , m_shouldDeleteLater(false)
 {
     connectLLM();
 }
@@ -43,6 +45,7 @@ void Chat::connectLLM()
     // Should be in different threads
     connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
@@ -55,8 +58,6 @@ void Chat::connectLLM()
     connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
     connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
     connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
-    connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
-    connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);

     // The following are blocking operations and will block the gui thread, therefore must be fast
@@ -122,6 +123,12 @@ void Chat::handleResponseChanged()
     emit responseChanged();
 }

+void Chat::handleModelLoadedChanged()
+{
+    if (m_shouldDeleteLater)
+        deleteLater();
+}
+
 void Chat::responseStarted()
 {
     m_responseInProgress = true;
@@ -180,15 +187,26 @@ void Chat::loadModel(const QString &modelName)
     emit loadModelRequested(modelName);
 }

+void Chat::unloadAndDeleteLater()
+{
+    if (!isModelLoaded()) {
+        deleteLater();
+        return;
+    }
+
+    m_shouldDeleteLater = true;
+    unloadModel();
+}
+
 void Chat::unloadModel()
 {
     stopGenerating();
-    emit unloadModelRequested();
+    m_llmodel->setShouldBeLoaded(false);
 }

 void Chat::reloadModel()
 {
-    emit reloadModelRequested(m_savedModelName);
+    m_llmodel->setShouldBeLoaded(true);
 }

 void Chat::generatedNameChanged()
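The new unloadAndDeleteLater()/handleModelLoadedChanged() pair defers destroying a Chat until its worker thread has actually given the model back: if nothing is loaded the chat can be deleted immediately, otherwise m_shouldDeleteLater is set, an unload is requested, and the queued isModelLoadedChanged notification triggers the deleteLater(). A minimal sketch of that general Qt pattern, with hypothetical Owner/Worker classes standing in for Chat/ChatLLM:

#include <QObject>

// Hypothetical Worker stands in for ChatLLM: it finishes some asynchronous teardown and
// then emits stopped(), which Owner receives via a queued connection.
class Worker : public QObject {
    Q_OBJECT
public:
    bool busy() const { return m_busy; }
    void requestStop() { m_busy = false; emit stopped(); } // asynchronous in the real code
Q_SIGNALS:
    void stopped();
private:
    bool m_busy = true;
};

class Owner : public QObject {
    Q_OBJECT
public:
    explicit Owner(Worker *worker) : m_worker(worker) {
        connect(m_worker, &Worker::stopped, this, &Owner::onWorkerStopped, Qt::QueuedConnection);
    }

    // Mirrors Chat::unloadAndDeleteLater(): delete now if there is nothing to wait for,
    // otherwise remember the request and delete once the worker reports it has stopped.
    void shutDownAndDelete() {
        if (!m_worker->busy()) {
            deleteLater();
            return;
        }
        m_deletePending = true;
        m_worker->requestStop();
    }

private Q_SLOTS:
    void onWorkerStopped() {
        if (m_deletePending)
            deleteLater(); // safe: delivered back on Owner's own thread
    }

private:
    Worker *m_worker = nullptr;
    bool m_deletePending = false;
};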
@@ -236,12 +254,10 @@ bool Chat::deserialize(QDataStream &stream, int version)
     stream >> m_userName;
     emit nameChanged();
     stream >> m_savedModelName;
     // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
     // unfortunately, we cannot deserialize these
     if (version < 2 && m_savedModelName.contains("gpt4all-j"))
         return false;
     if (!m_llmodel->deserialize(stream, version))
         return false;
     if (!m_chatModel->deserialize(stream, version))

chat.h

@@ -58,6 +58,7 @@ public:
     void loadModel(const QString &modelName);
     void unloadModel();
     void reloadModel();
+    void unloadAndDeleteLater();

     qint64 creationDate() const { return m_creationDate; }
     bool serialize(QDataStream &stream, int version) const;
@@ -87,8 +88,6 @@ Q_SIGNALS:
     void recalcChanged();
     void loadDefaultModelRequested();
     void loadModelRequested(const QString &modelName);
-    void unloadModelRequested();
-    void reloadModelRequested(const QString &modelName);
     void generateNameRequested();
     void modelListChanged();
     void modelLoadingError(const QString &error);
@@ -96,6 +95,7 @@ Q_SIGNALS:
 private Q_SLOTS:
     void handleResponseChanged();
+    void handleModelLoadedChanged();
     void responseStarted();
     void responseStopped();
     void generatedNameChanged();
@@ -112,6 +112,7 @@ private:
     qint64 m_creationDate;
     ChatLLM *m_llmodel;
     bool m_isServer;
+    bool m_shouldDeleteLater;
 };

 #endif // CHAT_H

chatlistmodel.cpp

@@ -40,6 +40,7 @@ void ChatListModel::setShouldSaveChats(bool b)

 void ChatListModel::removeChatFile(Chat *chat) const
 {
+    Q_ASSERT(chat != m_serverChat);
     const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
     QFile file(savePath + "/gpt4all-" + chat->id() + ".chat");
     if (!file.exists())
@@ -58,6 +59,8 @@ void ChatListModel::saveChats() const
     timer.start();
     const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
     for (Chat *chat : m_chats) {
+        if (chat == m_serverChat)
+            continue;
         QString fileName = "gpt4all-" + chat->id() + ".chat";
         QFile file(savePath + "/" + fileName);
         bool success = file.open(QIODevice::WriteOnly);

chatlistmodel.h

@@ -125,6 +125,7 @@ public:
     Q_INVOKABLE void removeChat(Chat* chat)
     {
+        Q_ASSERT(chat != m_serverChat);
         if (!m_chats.contains(chat)) {
             qWarning() << "WARNING: Removing chat failed with id" << chat->id();
             return;
@@ -138,11 +139,11 @@ public:
         }

         const int index = m_chats.indexOf(chat);
-        if (m_chats.count() < 2) {
+        if (m_chats.count() < 3 /*m_serverChat included*/) {
             addChat();
         } else {
             int nextIndex;
-            if (index == m_chats.count() - 1)
+            if (index == m_chats.count() - 2 /*m_serverChat is last*/)
                 nextIndex = index - 1;
             else
                 nextIndex = index + 1;
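For the bookkeeping above: the server chat now lives in m_chats and is kept last, so the old count < 2 / count - 1 checks become count < 3 / count - 2. Concretely, with m_chats = {chatA, chatB, serverChat}, removing chatB (index 1 == count - 2) selects chatA at index 0 as the next current chat, removing chatA selects chatB at index 1, and once only one user chat remains alongside the server chat a fresh chat is created before the removal.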
@@ -155,7 +156,7 @@ public:
         beginRemoveRows(QModelIndex(), newIndex, newIndex);
         m_chats.removeAll(chat);
         endRemoveRows();
-        delete chat;
+        chat->unloadAndDeleteLater();
     }

     Chat *currentChat() const
@@ -170,7 +171,7 @@ public:
             return;
         }

-        if (m_currentChat && m_currentChat->isModelLoaded())
+        if (m_currentChat)
             m_currentChat->unloadModel();

         m_currentChat = chat;
chatllm.cpp

@@ -15,6 +15,7 @@
 #include <fstream>

 //#define DEBUG
+//#define DEBUG_MODEL_LOADING

 #define MPT_INTERNAL_STATE_VERSION 0
 #define GPTJ_INTERNAL_STATE_VERSION 0
@@ -37,9 +38,51 @@ static QString modelFilePath(const QString &modelName)
     return QString();
 }

+class LLModelStore {
+public:
+    static LLModelStore *globalInstance();
+
+    LLModelInfo acquireModel(); // will block until llmodel is ready
+    void releaseModel(const LLModelInfo &info); // must be called when you are done
+
+private:
+    LLModelStore()
+    {
+        // seed with empty model
+        m_availableModels.append(LLModelInfo());
+    }
+    ~LLModelStore() {}
+    QVector<LLModelInfo> m_availableModels;
+    QMutex m_mutex;
+    QWaitCondition m_condition;
+    friend class MyLLModelStore;
+};
+
+class MyLLModelStore : public LLModelStore { };
+Q_GLOBAL_STATIC(MyLLModelStore, storeInstance)
+LLModelStore *LLModelStore::globalInstance()
+{
+    return storeInstance();
+}
+
+LLModelInfo LLModelStore::acquireModel()
+{
+    QMutexLocker locker(&m_mutex);
+    while (m_availableModels.isEmpty())
+        m_condition.wait(locker.mutex());
+    return m_availableModels.takeFirst();
+}
+
+void LLModelStore::releaseModel(const LLModelInfo &info)
+{
+    QMutexLocker locker(&m_mutex);
+    m_availableModels.append(info);
+    Q_ASSERT(m_availableModels.count() < 2);
+    m_condition.wakeAll();
+}
+
 ChatLLM::ChatLLM(Chat *parent)
     : QObject{nullptr}
-    , m_llmodel(nullptr)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_responseLogits(0)
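A note on the store above: the MyLLModelStore friend subclass appears to exist only so Q_GLOBAL_STATIC can default-construct the singleton while LLModelStore keeps its own constructor private, and the assert in releaseModel() documents that the store is meant to hold at most one entry at a time. A hypothetical two-thread demonstration (not part of the commit) of the handoff semantics, using only the store API declared above; the second acquireModel() blocks until the first caller releases:

// Hypothetical demonstration of the blocking handoff through the single-slot store.
void demoHandoff()
{
    LLModelInfo slot = LLModelStore::globalInstance()->acquireModel();  // take the only slot
    QThread *other = QThread::create([] {
        // Blocks here until the owner below calls releaseModel().
        LLModelInfo info = LLModelStore::globalInstance()->acquireModel();
        qDebug() << "second thread now owns the (possibly empty) model slot" << info.model;
        LLModelStore::globalInstance()->releaseModel(info);
    });
    other->start();                                                     // the lambda blocks in acquireModel()
    LLModelStore::globalInstance()->releaseModel(slot);                 // hand the slot over
    other->wait();                                                      // returns once the lambda has finished
    delete other;
}

In the real code the acquiring side is ChatLLM::loadModel() and the releasing side is ChatLLM::unloadModel(), so whichever chat the user switches to simply waits until the previous chat has saved its state and handed the model back.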
@@ -49,6 +92,7 @@ ChatLLM::ChatLLM(Chat *parent)
     moveToThread(&m_llmThread);
     connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
+    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
     connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
     m_llmThread.setObjectName(m_chat->id());
@@ -59,7 +103,13 @@ ChatLLM::~ChatLLM()
 {
     m_llmThread.quit();
     m_llmThread.wait();
-    delete m_llmodel;
+
+    // The only time we should have a model loaded here is on shutdown
+    // as we explicitly unload the model in all other circumstances
+    if (isModelLoaded()) {
+        delete m_modelInfo.model;
+        m_modelInfo.model = nullptr;
+    }
 }

 bool ChatLLM::loadDefaultModel()
@@ -76,50 +126,103 @@ bool ChatLLM::loadDefaultModel()
 bool ChatLLM::loadModel(const QString &modelName)
 {
+    // This is a complicated method because N different possible threads are interested in the outcome
+    // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
+    // possible chat threads all vying for a single resource - the currently loaded model - as the user
+    // switches back and forth between chats. It is important for our main/gui thread to never block
+    // but simultaneously always have up2date information with regards to which chat has the model loaded
+    // and what the type and name of that model is. I've tried to comment extensively in this method
+    // to provide an overview of what we're doing here.
+
+    // We're already loaded with this model
     if (isModelLoaded() && m_modelName == modelName)
         return true;

-    if (isModelLoaded()) {
+    QString filePath = modelFilePath(modelName);
+    QFileInfo fileInfo(filePath);
+
+    // We have a live model, but it isn't the one we want
+    bool alreadyAcquired = isModelLoaded();
+    if (alreadyAcquired) {
         resetContextProtected();
-        delete m_llmodel;
-        m_llmodel = nullptr;
+#if defined(DEBUG_MODEL_LOADING)
+        qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
+#endif
+        delete m_modelInfo.model;
+        m_modelInfo.model = nullptr;
         emit isModelLoadedChanged();
+    } else {
+        // This is a blocking call that tries to retrieve the model we need from the model store.
+        // If it succeeds, then we just have to restore state. If the store has never had a model
+        // returned to it, then the modelInfo.model pointer should be null which will happen on startup
+        m_modelInfo = LLModelStore::globalInstance()->acquireModel();
+#if defined(DEBUG_MODEL_LOADING)
+        qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
+#endif
+        // At this point it is possible that while we were blocked waiting to acquire the model from the
+        // store, that our state was changed to not be loaded. If this is the case, release the model
+        // back into the store and quit loading
+        if (!m_shouldBeLoaded) {
+            qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
+            LLModelStore::globalInstance()->releaseModel(m_modelInfo);
+            m_modelInfo = LLModelInfo();
+            emit isModelLoadedChanged();
+            return false;
+        }
+
+        // Check if the store just gave us exactly the model we were looking for
+        if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
+#if defined(DEBUG_MODEL_LOADING)
+            qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
+#endif
+            restoreState();
+            emit isModelLoadedChanged();
+            return true;
+        } else {
+            // Release the memory since we have to switch to a different model.
+#if defined(DEBUG_MODEL_LOADING)
+            qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
+#endif
+            delete m_modelInfo.model;
+            m_modelInfo.model = nullptr;
+        }
     }

-    bool isGPTJ = false;
-    bool isMPT = false;
-    QString filePath = modelFilePath(modelName);
-    QFileInfo info(filePath);
-    if (info.exists()) {
+    // Guarantee we've released the previous models memory
+    Q_ASSERT(!m_modelInfo.model);
+
+    // Store the file info in the modelInfo in case we have an error loading
+    m_modelInfo.fileInfo = fileInfo;
+
+    if (fileInfo.exists()) {
         auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
         uint32_t magic;
         fin.read((char *) &magic, sizeof(magic));
         fin.seekg(0);
         fin.close();
-        isGPTJ = magic == 0x67676d6c;
-        isMPT = magic == 0x67676d6d;
+        const bool isGPTJ = magic == 0x67676d6c;
+        const bool isMPT = magic == 0x67676d6d;
         if (isGPTJ) {
-            m_modelType = ModelType::GPTJ_;
-            m_llmodel = new GPTJ;
-            m_llmodel->loadModel(filePath.toStdString());
+            m_modelType = LLModelType::GPTJ_;
+            m_modelInfo.model = new GPTJ;
+            m_modelInfo.model->loadModel(filePath.toStdString());
         } else if (isMPT) {
-            m_modelType = ModelType::MPT_;
-            m_llmodel = new MPT;
-            m_llmodel->loadModel(filePath.toStdString());
+            m_modelType = LLModelType::MPT_;
+            m_modelInfo.model = new MPT;
+            m_modelInfo.model->loadModel(filePath.toStdString());
         } else {
-            m_modelType = ModelType::LLAMA_;
-            m_llmodel = new LLamaModel;
-            m_llmodel->loadModel(filePath.toStdString());
+            m_modelType = LLModelType::LLAMA_;
+            m_modelInfo.model = new LLamaModel;
+            m_modelInfo.model->loadModel(filePath.toStdString());
         }
-
-        restoreState();
-
+#if defined(DEBUG_MODEL_LOADING)
+        qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
+#endif
+        restoreState();
 #if defined(DEBUG)
-        qDebug() << "chatllm modelLoadedChanged" << m_chat->id();
+        qDebug() << "modelLoadedChanged" << m_chat->id();
         fflush(stdout);
 #endif
         emit isModelLoadedChanged();

         static bool isFirstLoad = true;
@@ -129,19 +232,20 @@ bool ChatLLM::loadModel(const QString &modelName)
         } else
             emit sendModelLoaded();
     } else {
+        LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
         const QString error = QString("Could not find model %1").arg(modelName);
         emit modelLoadingError(error);
     }

-    if (m_llmodel)
-        setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
+    if (m_modelInfo.model)
+        setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix

-    return m_llmodel;
+    return m_modelInfo.model;
 }

 bool ChatLLM::isModelLoaded() const
 {
-    return m_llmodel && m_llmodel->isModelLoaded();
+    return m_modelInfo.model && m_modelInfo.model->isModelLoaded();
 }

 void ChatLLM::regenerateResponse()
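In short, the rewritten loadModel() has three outcomes: if this chat already holds the right model it returns immediately; if it holds the wrong one it deletes it in place; otherwise it blocks in acquireModel() until another chat releases the shared slot (the store is seeded with an empty slot, so the very first acquire returns at once). It then either reuses the model it received, if the file matches, or deletes it and loads the requested file from disk; if the model file cannot be found, the empty slot is released back into the store so other chats are not left blocked. The m_shouldBeLoaded check after the blocking acquire covers the case where the user switched away from this chat while it was still waiting.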
@@ -226,7 +330,7 @@ bool ChatLLM::handlePrompt(int32_t token)
     // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
     // the entire context window which we can reset on regenerate prompt
 #if defined(DEBUG)
-    qDebug() << "chatllm prompt process" << m_chat->id() << token;
+    qDebug() << "prompt process" << m_chat->id() << token;
 #endif
     ++m_promptTokens;
     ++m_promptResponseTokens;
@@ -287,12 +391,12 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     m_ctx.n_batch = n_batch;
     m_ctx.repeat_penalty = repeat_penalty;
     m_ctx.repeat_last_n = repeat_penalty_tokens;
-    m_llmodel->setThreadCount(n_threads);
+    m_modelInfo.model->setThreadCount(n_threads);
 #if defined(DEBUG)
     printf("%s", qPrintable(instructPrompt));
     fflush(stdout);
 #endif
-    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
+    m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -307,26 +411,55 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
     return true;
 }

+void ChatLLM::setShouldBeLoaded(bool b)
+{
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
+#endif
+    m_shouldBeLoaded = b; // atomic
+    emit shouldBeLoadedChanged();
+}
+
+void ChatLLM::handleShouldBeLoadedChanged()
+{
+    if (m_shouldBeLoaded)
+        reloadModel();
+    else
+        unloadModel();
+}
+
+void ChatLLM::forceUnloadModel()
+{
+    m_shouldBeLoaded = false; // atomic
+    unloadModel();
+}
+
 void ChatLLM::unloadModel()
 {
-#if defined(DEBUG)
-    qDebug() << "chatllm unloadModel" << m_chat->id();
-#endif
+    if (!isModelLoaded())
+        return;
+
     saveState();
-    delete m_llmodel;
-    m_llmodel = nullptr;
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
+#endif
+    LLModelStore::globalInstance()->releaseModel(m_modelInfo);
+    m_modelInfo = LLModelInfo();
     emit isModelLoadedChanged();
 }

-void ChatLLM::reloadModel(const QString &modelName)
+void ChatLLM::reloadModel()
 {
-#if defined(DEBUG)
-    qDebug() << "chatllm reloadModel" << m_chat->id();
+    if (isModelLoaded())
+        return;
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
 #endif
-    if (modelName.isEmpty()) {
+    if (m_modelName.isEmpty()) {
         loadDefaultModel();
     } else {
-        loadModel(modelName);
+        loadModel(m_modelName);
     }
 }
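The setShouldBeLoaded()/handleShouldBeLoadedChanged() pair above is what keeps the GUI responsive: any thread may flip the atomic m_shouldBeLoaded flag and emit the signal, but because the connection made in the constructor is queued, the slow load or unload always runs later on the ChatLLM worker thread. A minimal sketch of that pattern in isolation, assuming Qt; the Controller class and its doLoad()/doUnload() stubs are hypothetical:

#include <QObject>
#include <atomic>

// Callers on any thread record their intent in an atomic flag; the actual work happens
// on the object's own thread because the signal/slot connection is queued.
class Controller : public QObject
{
    Q_OBJECT
public:
    Controller() {
        connect(this, &Controller::wantedChanged,
                this, &Controller::onWantedChanged, Qt::QueuedConnection);
    }

    void setWanted(bool wanted)            // safe to call from any thread
    {
        m_wanted = wanted;                 // atomic store of the latest intent
        emit wantedChanged();              // handler runs later, on this object's thread
    }

Q_SIGNALS:
    void wantedChanged();

private Q_SLOTS:
    void onWantedChanged()
    {
        // Only the most recent intent matters; stale queued signals are harmless
        // because the flag is re-read here.
        if (m_wanted)
            doLoad();
        else
            doUnload();
    }

private:
    void doLoad() { /* expensive work, e.g. loading a model */ }
    void doUnload() { /* release the resource */ }

    std::atomic<bool> m_wanted{false};
};

Chat::unloadModel() and Chat::reloadModel() now reduce to exactly this kind of call: they simply flip the flag via setShouldBeLoaded() instead of emitting dedicated unload/reload signals.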
@@ -348,7 +481,7 @@ void ChatLLM::generateName()
     printf("%s", qPrintable(instructPrompt));
     fflush(stdout);
 #endif
-    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
+    m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
 #if defined(DEBUG)
     printf("\n");
     fflush(stdout);
@@ -415,7 +548,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
     QByteArray compressed = qCompress(m_state);
     stream << compressed;
 #if defined(DEBUG)
-    qDebug() << "chatllm serialize" << m_chat->id() << m_state.size();
+    qDebug() << "serialize" << m_chat->id() << m_state.size();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -452,7 +585,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
         stream >> m_state;
     }
 #if defined(DEBUG)
-    qDebug() << "chatllm deserialize" << m_chat->id();
+    qDebug() << "deserialize" << m_chat->id();
 #endif
     return stream.status() == QDataStream::Ok;
 }
@@ -462,12 +595,12 @@ void ChatLLM::saveState()
     if (!isModelLoaded())
         return;

-    const size_t stateSize = m_llmodel->stateSize();
+    const size_t stateSize = m_modelInfo.model->stateSize();
     m_state.resize(stateSize);
 #if defined(DEBUG)
-    qDebug() << "chatllm saveState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
 #endif
-    m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+    m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 }

 void ChatLLM::restoreState()
@@ -476,9 +609,9 @@ void ChatLLM::restoreState()
         return;

 #if defined(DEBUG)
-    qDebug() << "chatllm restoreState" << m_chat->id() << "size:" << m_state.size();
+    qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
 #endif
-    m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
+    m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
     m_state.clear();
     m_state.resize(0);
 }

chatllm.h

@@ -3,9 +3,23 @@
 #include <QObject>
 #include <QThread>
+#include <QFileInfo>

 #include "../gpt4all-backend/llmodel.h"

+enum LLModelType {
+    MPT_,
+    GPTJ_,
+    LLAMA_
+};
+
+struct LLModelInfo {
+    LLModel *model = nullptr;
+    QFileInfo fileInfo;
+    // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
+    // must be able to serialize the information even if it is in the unloaded state
+};
+
 class Chat;
 class ChatLLM : public QObject
 {
@@ -17,12 +31,6 @@ class ChatLLM : public QObject
     Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)

 public:
-    enum ModelType {
-        MPT_,
-        GPTJ_,
-        LLAMA_
-    };
-
     ChatLLM(Chat *parent);
     virtual ~ChatLLM();

@@ -33,6 +41,9 @@ public:
     void stopGenerating() { m_stopGenerating = true; }

+    bool shouldBeLoaded() const { return m_shouldBeLoaded; }
+    void setShouldBeLoaded(bool b);
+
     QString response() const;
     QString modelName() const;

@@ -52,10 +63,12 @@ public Q_SLOTS:
     bool loadDefaultModel();
     bool loadModel(const QString &modelName);
     void modelNameChangeRequested(const QString &modelName);
+    void forceUnloadModel();
     void unloadModel();
-    void reloadModel(const QString &modelName);
+    void reloadModel();
     void generateName();
     void handleChatIdChanged();
+    void handleShouldBeLoadedChanged();

 Q_SIGNALS:
     void isModelLoadedChanged();
@@ -71,6 +84,7 @@ Q_SIGNALS:
     void generatedNameChanged();
     void stateChanged();
     void threadStarted();
+    void shouldBeLoadedChanged();

 protected:
     LLModel::PromptContext m_ctx;
@@ -89,16 +103,17 @@ private:
     void restoreState();

 private:
-    LLModel *m_llmodel;
+    LLModelInfo m_modelInfo;
+    LLModelType m_modelType;
     std::string m_response;
     std::string m_nameResponse;
     quint32 m_responseLogits;
     QString m_modelName;
-    ModelType m_modelType;
     Chat *m_chat;
     QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
+    std::atomic<bool> m_shouldBeLoaded;
     bool m_isRecalc;
 };