replace setShouldBeLoaded with loadModelAsync/releaseModelAsync

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
Jared Van Bortel 2024-08-09 11:05:06 -04:00
parent 05bd6042b6
commit 8fd9f01578
6 changed files with 61 additions and 74 deletions


@@ -74,7 +74,6 @@ void Chat::connectLLM()
connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
@@ -277,25 +276,23 @@ void Chat::markForDeletion()
void Chat::unloadModel()
{
stopGenerating();
m_llmodel->setShouldBeLoaded(false);
m_llmodel->releaseModelAsync();
}
void Chat::reloadModel()
{
m_llmodel->setShouldBeLoaded(true);
m_llmodel->loadModelAsync();
}
void Chat::forceUnloadModel()
{
stopGenerating();
m_llmodel->setForceUnloadModel(true);
m_llmodel->setShouldBeLoaded(false);
m_llmodel->releaseModelAsync(/*unload*/ true);
}
void Chat::forceReloadModel()
{
m_llmodel->setForceUnloadModel(true);
m_llmodel->setShouldBeLoaded(true);
m_llmodel->loadModelAsync(/*reload*/ true);
}
void Chat::trySwitchContextOfLoadedModel()
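For reference, a rough mapping from the removed flag calls to the new request calls, as used by the Chat methods above (a summary of this diff, not part of the commit):

// reloadModel():      setShouldBeLoaded(true)                              -> loadModelAsync()
// unloadModel():      setShouldBeLoaded(false)                             -> releaseModelAsync()
// forceReloadModel(): setForceUnloadModel(true) + setShouldBeLoaded(true)  -> loadModelAsync(/*reload*/ true)
// forceUnloadModel(): setForceUnloadModel(true) + setShouldBeLoaded(false) -> releaseModelAsync(/*unload*/ true)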


@@ -145,7 +145,6 @@ Q_SIGNALS:
void modelChangeRequested(const ModelInfo &modelInfo);
void modelInfoChanged();
void restoringFromTextChanged();
void loadDefaultModelRequested();
void loadModelRequested(const ModelInfo &modelInfo);
void generateNameRequested();
void modelLoadingErrorChanged();


@@ -106,7 +106,6 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
, m_promptTokens(0)
, m_restoringFromText(false)
, m_shouldBeLoaded(false)
, m_forceUnloadModel(false)
, m_markedForDeletion(false)
, m_stopGenerating(false)
, m_timer(nullptr)
@@ -117,8 +116,10 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
, m_restoreStateFromText(false)
{
moveToThread(&m_llmThread);
connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
Qt::QueuedConnection); // explicitly queued
connect<void(LlamaCppModel::*)(bool), void(LlamaCppModel::*)(bool)>(
this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel
);
connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel);
connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
Qt::QueuedConnection); // explicitly queued
connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
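The explicit template arguments on the requestLoadModel connection are needed because loadModel is overloaded: loadModel(bool reload) added below and the existing loadModel(const ModelInfo &), so a plain &LlamaCppModel::loadModel would be ambiguous. Because the object was moved to m_llmThread, the default Qt::AutoConnection delivers these signals as queued calls whenever they are emitted from another thread, which is what makes the *Async methods safe to call from the GUI thread. A minimal sketch of the same disambiguation using qOverload, with a hypothetical Worker class rather than the project's types:

#include <QObject>
#include <QString>

class Worker : public QObject
{
    Q_OBJECT
public Q_SLOTS:
    void run(bool reload) { Q_UNUSED(reload); }
    void run(const QString &name) { Q_UNUSED(name); }
Q_SIGNALS:
    void runRequested(bool reload);
};

static void wire(Worker *w)
{
    // &Worker::run alone is ambiguous; select the bool overload explicitly.
    QObject::connect(w, &Worker::runRequested, w, qOverload<bool>(&Worker::run));
}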
@@ -170,8 +171,7 @@ void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
m_forceMetal = forceMetal;
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
loadModel(/*reload*/ true);
m_reloadingToChangeVariant = false;
}
#endif
@@ -181,22 +181,11 @@ void LlamaCppModel::handleDeviceChanged()
{
if (isModelLoaded() && m_shouldBeLoaded) {
m_reloadingToChangeVariant = true;
unloadModel();
reloadModel();
loadModel(/*reload*/ true);
m_reloadingToChangeVariant = false;
}
}
bool LlamaCppModel::loadDefaultModel()
{
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
if (defaultModel.filename().isEmpty()) {
emit modelLoadingError(u"Could not find any model to load"_qs);
return false;
}
return loadModel(defaultModel);
}
void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
// We're trying to see if the store already has the model fully loaded that we wish to use
@@ -820,13 +809,16 @@ bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const Q
return true;
}
void LlamaCppModel::setShouldBeLoaded(bool b)
void LlamaCppModel::loadModelAsync(bool reload)
{
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
#endif
m_shouldBeLoaded = b; // atomic
emit shouldBeLoadedChanged();
m_shouldBeLoaded = true; // atomic
emit requestLoadModel(reload);
}
void LlamaCppModel::releaseModelAsync(bool unload)
{
m_shouldBeLoaded = false; // atomic
emit requestReleaseModel(unload);
}
void LlamaCppModel::requestTrySwitchContext()
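loadModelAsync and releaseModelAsync are the thread-safe entry points: they record the caller's latest intent in the atomic m_shouldBeLoaded flag immediately, then emit a request signal that runs the actual load or release on the model's thread. Because loadModel(/*reload*/ true) releases any currently loaded model itself, the separate m_forceUnloadModel flag is no longer needed. A self-contained sketch of the pattern, with a hypothetical ModelHost class standing in for LlamaCppModel:

#include <QObject>
#include <atomic>

class ModelHost : public QObject
{
    Q_OBJECT
public:
    ModelHost()
    {
        // Qt::AutoConnection: queued whenever loadAsync()/releaseAsync() run on another thread.
        connect(this, &ModelHost::loadRequested, this, &ModelHost::doLoad);
        connect(this, &ModelHost::releaseRequested, this, &ModelHost::doRelease);
    }

    void loadAsync()    { m_shouldBeLoaded = true;  emit loadRequested(); }   // callable from any thread
    void releaseAsync() { m_shouldBeLoaded = false; emit releaseRequested(); }

Q_SIGNALS:
    void loadRequested();
    void releaseRequested();

private Q_SLOTS:
    void doLoad()
    {
        if (!m_shouldBeLoaded)
            return; // a release request arrived after this load request was queued
        // ... expensive model load runs here, on the thread this object lives on ...
    }
    void doRelease()
    {
        // runs unconditionally; m_shouldBeLoaded only records the caller's latest intent
        // ... free the backend instance ...
    }

private:
    std::atomic<bool> m_shouldBeLoaded{false};
};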
@@ -835,23 +827,43 @@ void LlamaCppModel::requestTrySwitchContext()
emit trySwitchContextRequested(modelInfo());
}
void LlamaCppModel::handleShouldBeLoadedChanged()
void LlamaCppModel::loadModel(bool reload)
{
if (m_shouldBeLoaded)
reloadModel();
else
unloadModel();
Q_ASSERT(m_shouldBeLoaded);
if (m_isServer)
return; // server managed models directly
if (reload)
releaseModel(/*unload*/ true);
else if (isModelLoaded())
return; // already loaded
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
ModelInfo m = modelInfo();
if (m.name().isEmpty()) {
ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
if (defaultModel.filename().isEmpty()) {
emit modelLoadingError(u"Could not find any model to load"_s);
return;
}
m = defaultModel;
}
loadModel(m);
}
void LlamaCppModel::unloadModel()
void LlamaCppModel::releaseModel(bool unload)
{
if (!isModelLoaded() || m_isServer)
return;
if (!m_forceUnloadModel || !m_shouldBeLoaded)
if (unload && m_shouldBeLoaded) {
// reloading the model, don't show unloaded status
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small positive value
} else {
emit modelLoadingPercentageChanged(0.0f);
else
emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
}
if (!m_markedForDeletion)
saveState();
@@ -860,33 +872,14 @@ void LlamaCppModel::unloadModel()
qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
if (m_forceUnloadModel) {
if (unload) {
m_llModelInfo.resetModel(this);
m_forceUnloadModel = false;
}
LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
m_pristineLoadedState = false;
}
void LlamaCppModel::reloadModel()
{
if (isModelLoaded() && m_forceUnloadModel)
unloadModel(); // we unload first if we are forcing an unload
if (isModelLoaded() || m_isServer)
return;
#if defined(DEBUG_MODEL_LOADING)
qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
#endif
const ModelInfo m = modelInfo();
if (m.name().isEmpty())
loadDefaultModel();
else
loadModel(m);
}
void LlamaCppModel::generateName()
{
Q_ASSERT(isModelLoaded());
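A note on the modelLoadingPercentageChanged values in releaseModel above: 0.0f reports a fully unloaded model, while std::numeric_limits<float>::min() (the smallest positive normal float, about 1.2e-38) is a deliberately tiny nonzero value, so the reported progress never drops to exactly zero during a reload and a listener that treats zero as "unloaded" keeps showing a busy state instead. A hypothetical listener, assuming progress values in the range [0, 1]:

#include <QString>

// Illustrative only; the labels are not from the diff.
static QString describeLoadingProgress(float progress)
{
    if (progress == 0.0f)
        return QStringLiteral("unloaded"); // only an explicit unload reports exactly zero
    if (progress < 1.0f)
        return QStringLiteral("loading");  // includes the tiny sentinel emitted while a reload is in flight
    return QStringLiteral("loaded");
}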


@@ -109,9 +109,9 @@ public:
void stopGenerating() override { m_stopGenerating = true; }
void setShouldBeLoaded(bool b) override;
void loadModelAsync(bool reload = false) override;
void releaseModelAsync(bool unload = false) override;
void requestTrySwitchContext() override;
void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
void setModelInfo(const ModelInfo &info) override;
@@ -147,14 +147,14 @@ public:
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
bool loadDefaultModel() override;
bool loadModel(const ModelInfo &modelInfo) override;
void modelChangeRequested(const ModelInfo &modelInfo) override;
void generateName() override;
void processSystemPrompt() override;
Q_SIGNALS:
void shouldBeLoadedChanged();
void requestLoadModel(bool reload);
void requestReleaseModel(bool unload);
protected:
bool isModelLoaded() const;
@@ -183,11 +183,10 @@ protected:
protected Q_SLOTS:
void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
void unloadModel();
void reloadModel();
void loadModel(bool reload = false);
void releaseModel(bool unload = false);
void generateQuestions(qint64 elapsed);
void handleChatIdChanged(const QString &id);
void handleShouldBeLoadedChanged();
void handleThreadStarted();
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
@@ -197,11 +196,13 @@ private:
bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
protected:
ModelBackend::PromptContext m_ctx;
// used by Server
quint32 m_promptTokens;
quint32 m_promptResponseTokens;
std::atomic<bool> m_shouldBeLoaded;
private:
ModelBackend::PromptContext m_ctx;
std::string m_response;
std::string m_nameResponse;
QString m_questionResponse;
@@ -212,9 +213,7 @@ private:
QByteArray m_state;
QThread m_llmThread;
std::atomic<bool> m_stopGenerating;
std::atomic<bool> m_shouldBeLoaded;
std::atomic<bool> m_restoringFromText; // status indication
std::atomic<bool> m_forceUnloadModel;
std::atomic<bool> m_markedForDeletion;
bool m_isServer;
bool m_forceMetal;


@@ -30,9 +30,9 @@ public:
virtual void stopGenerating() = 0;
virtual void setShouldBeLoaded(bool b) = 0;
virtual void loadModelAsync(bool reload = false) = 0;
virtual void releaseModelAsync(bool unload = false) = 0;
virtual void requestTrySwitchContext() = 0;
virtual void setForceUnloadModel(bool b) = 0;
virtual void setMarkedForDeletion(bool b) = 0;
virtual void setModelInfo(const ModelInfo &info) = 0;
@@ -45,7 +45,6 @@ public:
public Q_SLOTS:
virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
virtual bool loadDefaultModel() = 0;
virtual bool loadModel(const ModelInfo &modelInfo) = 0;
virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
virtual void generateName() = 0;


@@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
emit requestServerNewPromptResponsePair(actualPrompt); // blocks
// load the new model if necessary
setShouldBeLoaded(true);
m_shouldBeLoaded = true;
if (modelInfo.filename().isEmpty()) {
std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
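The Server path sets the now-protected atomic m_shouldBeLoaded flag directly in place of the removed setShouldBeLoaded(true) call: loadModel(bool) returns early for the server ("server managed models directly" above), and m_shouldBeLoaded was moved into the protected section of the header and marked "used by Server" so the Server subclass can keep updating it.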