replace setShouldBeLoaded with loadModelAsync/releaseModelAsync

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-08-09 11:05:06 -04:00
parent 05bd6042b6
commit 8fd9f01578
6 changed files with 61 additions and 74 deletions
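
At the call sites in Chat, the migration collapses the old setter pair plus implicitly queued handler into one explicit request per intent. A rough before/after sketch, condensed from the hunks below:

    // Before: intent split across two setters, applied via a queued handler
    m_llmodel->setForceUnloadModel(true);
    m_llmodel->setShouldBeLoaded(false);  // shouldBeLoadedChanged -> handleShouldBeLoadedChanged() -> unloadModel()

    // After: one call per intent, forwarded to the LLM thread by a dedicated signal
    m_llmodel->releaseModelAsync(/*unload*/ true);  // requestReleaseModel(true) -> releaseModel(true)
    m_llmodel->loadModelAsync(/*reload*/ true);     // requestLoadModel(true)    -> loadModel(true)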


@ -74,7 +74,6 @@ void Chat::connectLLM()
connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection); connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection); connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection); connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection); connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection); connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
@@ -277,25 +276,23 @@ void Chat::markForDeletion()
 void Chat::unloadModel()
 {
     stopGenerating();
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync();
 }
 
 void Chat::reloadModel()
 {
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync();
 }
 
 void Chat::forceUnloadModel()
 {
     stopGenerating();
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync(/*unload*/ true);
 }
 
 void Chat::forceReloadModel()
 {
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync(/*reload*/ true);
 }
 
 void Chat::trySwitchContextOfLoadedModel()


@@ -145,7 +145,6 @@ Q_SIGNALS:
     void modelChangeRequested(const ModelInfo &modelInfo);
     void modelInfoChanged();
     void restoringFromTextChanged();
-    void loadDefaultModelRequested();
     void loadModelRequested(const ModelInfo &modelInfo);
     void generateNameRequested();
     void modelLoadingErrorChanged();


@@ -106,7 +106,6 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_promptTokens(0)
     , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
-    , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
@@ -117,8 +116,10 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
-    connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
-            Qt::QueuedConnection); // explicitly queued
+    connect<void(LlamaCppModel::*)(bool), void(LlamaCppModel::*)(bool)>(
+        this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel
+    );
+    connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel);
     connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
             Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
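
The load connection spells out the member-function-pointer types because LlamaCppModel::loadModel is overloaded after this commit (loadModel(bool) as a protected slot, loadModel(const ModelInfo &) as a public slot), so a bare &LlamaCppModel::loadModel would be ambiguous; requestReleaseModel needs no such treatment. A minimal sketch of the same disambiguation with hypothetical names, using qOverload, the more common spelling of the same thing:

    #include <QObject>

    class Worker : public QObject {
        Q_OBJECT
    public slots:
        void load(bool reload);              // overload 1
        void load(const QString &modelName); // overload 2
    signals:
        void requestLoad(bool reload);
    };

    void wire(Worker *w) {
        // qOverload<bool> selects the load(bool) overload; the commit instead passes the
        // pointer-to-member types as explicit template arguments to connect().
        QObject::connect(w, &Worker::requestLoad, w, qOverload<bool>(&Worker::load));
    }
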
@@ -170,8 +171,7 @@ void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
     m_forceMetal = forceMetal;
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 #endif
@@ -181,22 +181,11 @@ void LlamaCppModel::handleDeviceChanged()
 {
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 }
 
-bool LlamaCppModel::loadDefaultModel()
-{
-    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
-    if (defaultModel.filename().isEmpty()) {
-        emit modelLoadingError(u"Could not find any model to load"_qs);
-        return false;
-    }
-    return loadModel(defaultModel);
-}
-
 void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
@@ -820,13 +809,16 @@ bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const Q
     return true;
 }
 
-void LlamaCppModel::setShouldBeLoaded(bool b)
+void LlamaCppModel::loadModelAsync(bool reload)
 {
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
-#endif
-    m_shouldBeLoaded = b; // atomic
-    emit shouldBeLoadedChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit requestLoadModel(reload);
+}
+
+void LlamaCppModel::releaseModelAsync(bool unload)
+{
+    m_shouldBeLoaded = false; // atomic
+    emit requestReleaseModel(unload);
 }
 
 void LlamaCppModel::requestTrySwitchContext()
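
loadModelAsync and releaseModelAsync run on whatever thread calls them: they record the intent in the atomic m_shouldBeLoaded immediately, then emit a request signal. Because the object was moved to m_llmThread in the constructor, the connected loadModel/releaseModel slots still execute on the LLM thread; with the default automatic connection the emission becomes a queued call whenever the caller sits on a different thread. A generic sketch of the same pattern, with hypothetical names:

    #include <QCoreApplication>
    #include <QObject>
    #include <QThread>
    #include <atomic>

    class ModelWorker : public QObject {
        Q_OBJECT
    public:
        ModelWorker() { connect(this, &ModelWorker::requestLoad, this, &ModelWorker::doLoad); }

        // Callable from any thread: record intent, then hop to the worker thread.
        void loadAsync(bool reload) {
            m_shouldBeLoaded = true;   // atomic, readable from both threads
            emit requestLoad(reload);  // cross-thread emit => delivered as a queued call
        }

    signals:
        void requestLoad(bool reload);

    private slots:
        void doLoad(bool reload) { Q_UNUSED(reload); /* runs on the thread this object lives on */ }

    private:
        std::atomic<bool> m_shouldBeLoaded{false};
    };

    int main(int argc, char *argv[]) {
        QCoreApplication app(argc, argv);
        QThread llmThread;
        llmThread.start();
        ModelWorker worker;
        worker.moveToThread(&llmThread);
        worker.loadAsync(/*reload*/ false);  // safe to call from the main thread
        return app.exec();
    }
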
@@ -835,23 +827,43 @@ void LlamaCppModel::requestTrySwitchContext()
     emit trySwitchContextRequested(modelInfo());
 }
 
-void LlamaCppModel::handleShouldBeLoadedChanged()
+void LlamaCppModel::loadModel(bool reload)
 {
-    if (m_shouldBeLoaded)
-        reloadModel();
-    else
-        unloadModel();
+    Q_ASSERT(m_shouldBeLoaded);
+    if (m_isServer)
+        return; // server managed models directly
+
+    if (reload)
+        releaseModel(/*unload*/ true);
+    else if (isModelLoaded())
+        return; // already loaded
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
+#endif
+    ModelInfo m = modelInfo();
+    if (m.name().isEmpty()) {
+        ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
+        if (defaultModel.filename().isEmpty()) {
+            emit modelLoadingError(u"Could not find any model to load"_s);
+            return;
+        }
+        m = defaultModel;
+    }
+    loadModel(m);
 }
 
-void LlamaCppModel::unloadModel()
+void LlamaCppModel::releaseModel(bool unload)
 {
     if (!isModelLoaded() || m_isServer)
         return;
 
-    if (!m_forceUnloadModel || !m_shouldBeLoaded)
+    if (unload && m_shouldBeLoaded) {
+        // reloading the model, don't show unloaded status
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small positive value
+    } else {
         emit modelLoadingPercentageChanged(0.0f);
-    else
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    }
 
     if (!m_markedForDeletion)
         saveState();
@@ -860,33 +872,14 @@ void LlamaCppModel::unloadModel()
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    if (m_forceUnloadModel) {
+    if (unload) {
         m_llModelInfo.resetModel(this);
-        m_forceUnloadModel = false;
     }
 
     LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
     m_pristineLoadedState = false;
 }
 
-void LlamaCppModel::reloadModel()
-{
-    if (isModelLoaded() && m_forceUnloadModel)
-        unloadModel(); // we unload first if we are forcing an unload
-
-    if (isModelLoaded() || m_isServer)
-        return;
-
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
-#endif
-
-    const ModelInfo m = modelInfo();
-    if (m.name().isEmpty())
-        loadDefaultModel();
-    else
-        loadModel(m);
-}
-
 void LlamaCppModel::generateName()
 {
     Q_ASSERT(isModelLoaded());
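
releaseModel reuses the loading percentage as a small status channel: 0.0f means the model is genuinely unloaded, while std::numeric_limits<float>::min() is a positive-but-negligible value that keeps percentage consumers in their loading branch during a reload, so the UI never flashes an unloaded state between release and the subsequent load. A sketch of how a consumer of modelLoadingPercentageChanged might read the value; the enum and function are illustrative, not part of the codebase:

    #include <limits>

    enum class LoadStatus { Unloaded, Loading, Loaded };

    LoadStatus statusFor(float pct) {
        if (pct <= 0.0f)
            return LoadStatus::Unloaded;  // releaseModel(unload) with no reload pending
        if (pct < 1.0f)
            return LoadStatus::Loading;   // covers numeric_limits<float>::min() during a reload
        return LoadStatus::Loaded;
    }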


@@ -109,9 +109,9 @@ public:
     void stopGenerating() override { m_stopGenerating = true; }
 
-    void setShouldBeLoaded(bool b) override;
+    void loadModelAsync(bool reload = false) override;
+    void releaseModelAsync(bool unload = false) override;
     void requestTrySwitchContext() override;
-    void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
     void setModelInfo(const ModelInfo &info) override;
@@ -147,14 +147,14 @@ public:
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
-    bool loadDefaultModel() override;
     bool loadModel(const ModelInfo &modelInfo) override;
     void modelChangeRequested(const ModelInfo &modelInfo) override;
     void generateName() override;
     void processSystemPrompt() override;
 
 Q_SIGNALS:
-    void shouldBeLoadedChanged();
+    void requestLoadModel(bool reload);
+    void requestReleaseModel(bool unload);
 
 protected:
     bool isModelLoaded() const;
@@ -183,11 +183,10 @@ protected:
 protected Q_SLOTS:
     void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
+    void loadModel(bool reload = false);
+    void releaseModel(bool unload = false);
     void generateQuestions(qint64 elapsed);
     void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -197,11 +196,13 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
 protected:
-    ModelBackend::PromptContext m_ctx; // used by Server
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
+    std::atomic<bool> m_shouldBeLoaded;
 
 private:
+    ModelBackend::PromptContext m_ctx;
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -212,9 +213,7 @@ private:
     QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
-    std::atomic<bool> m_shouldBeLoaded;
     std::atomic<bool> m_restoringFromText; // status indication
-    std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
     bool m_isServer;
     bool m_forceMetal;
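
m_shouldBeLoaded moves from the private section into the protected one because the Server subclass now assigns it directly (see the server.cpp hunk below) instead of going through the removed setShouldBeLoaded(). A hypothetical fragment showing that access pattern; ensureModelLoaded is not a method in this diff and only illustrates why protected access is needed:

    // Illustrative only: Server derives from the chat model class, so a protected
    // member lets it record the load intent itself before driving the load,
    // which is also why loadModel(bool) early-returns when m_isServer is set.
    class Server : public LlamaCppModel {
    protected:
        void ensureModelLoaded(const ModelInfo &info) {
            m_shouldBeLoaded = true;  // previously: setShouldBeLoaded(true)
            loadModel(info);          // the ModelInfo overload, a public slot
        }
    };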


@@ -30,9 +30,9 @@ public:
     virtual void stopGenerating() = 0;
 
-    virtual void setShouldBeLoaded(bool b) = 0;
+    virtual void loadModelAsync(bool reload = false) = 0;
+    virtual void releaseModelAsync(bool unload = false) = 0;
     virtual void requestTrySwitchContext() = 0;
-    virtual void setForceUnloadModel(bool b) = 0;
     virtual void setMarkedForDeletion(bool b) = 0;
     virtual void setModelInfo(const ModelInfo &info) = 0;
@@ -45,7 +45,6 @@ public:
 public Q_SLOTS:
     virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
-    virtual bool loadDefaultModel() = 0;
     virtual bool loadModel(const ModelInfo &modelInfo) = 0;
     virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
     virtual void generateName() = 0;


@@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     emit requestServerNewPromptResponsePair(actualPrompt); // blocks
 
     // load the new model if necessary
-    setShouldBeLoaded(true);
+    m_shouldBeLoaded = true;
 
     if (modelInfo.filename().isEmpty()) {
         std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;