gpt4all (mirror of https://github.com/nomic-ai/gpt4all.git)

commit 8fd9f01578 (parent 05bd6042b6)

    replace setShouldBeLoaded with loadModelAsync/releaseModelAsync

    Signed-off-by: Jared Van Bortel <jared@nomic.ai>
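The commit replaces the backend's single load/unload flag setter with two explicit asynchronous entry points. As a condensed sketch of the interface change (assuming only the simplified signatures visible in the hunks below; the enclosing class and unrelated members are elided):

```cpp
// Condensed sketch of the model-backend interface change; see the pure-virtual
// hunk near the end of the diff. Everything except these declarations is elided.
class LLModel /* : public QObject */ {
public:
    // Before: one flag setter; the worker reacted to shouldBeLoadedChanged and
    // decided on its own whether the new value meant "load" or "unload".
    // virtual void setShouldBeLoaded(bool b) = 0;

    // After: two explicit async requests. Each records the intent in an atomic
    // flag and emits a signal, so both return immediately on the caller's thread.
    virtual void loadModelAsync(bool reload = false) = 0;
    virtual void releaseModelAsync(bool unload = false) = 0;
};
```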
@@ -74,7 +74,6 @@ void Chat::connectLLM()
     connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
-    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
     connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
     connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
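These Chat-to-backend connections are explicitly queued because the two objects live on different threads; the backend is moved onto its own QThread in the constructor hunk further down. For readers less familiar with Qt threading, a minimal self-contained sketch of that pattern follows (class names are hypothetical; only the moveToThread and queued-connect mechanics mirror the diff):

```cpp
// Minimal illustration of a queued cross-thread connection (hypothetical names).
// As a Q_OBJECT-based sketch it needs moc, i.e. a qmake/CMake (AUTOMOC) build.
#include <QCoreApplication>
#include <QObject>
#include <QThread>
#include <QTimer>

class Backend : public QObject {
    Q_OBJECT
public slots:
    void doWork() { qDebug("doWork runs on the worker thread"); }
};

class Frontend : public QObject {
    Q_OBJECT
public:
    void request() { emit workRequested(); }   // called from the main thread
signals:
    void workRequested();
};

int main(int argc, char *argv[])
{
    QCoreApplication app(argc, argv);

    QThread workerThread;
    workerThread.start();

    Backend backend;
    backend.moveToThread(&workerThread);        // backend now lives on workerThread

    Frontend frontend;
    // Queued: the emit returns immediately; doWork() runs later inside
    // workerThread's event loop, never on the caller's thread.
    QObject::connect(&frontend, &Frontend::workRequested,
                     &backend, &Backend::doWork, Qt::QueuedConnection);

    frontend.request();
    QTimer::singleShot(100, &app, &QCoreApplication::quit); // let the queued call run
    int ret = app.exec();

    workerThread.quit();
    workerThread.wait();
    return ret;
}
```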
@@ -277,25 +276,23 @@ void Chat::markForDeletion()
 void Chat::unloadModel()
 {
     stopGenerating();
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync();
 }
 
 void Chat::reloadModel()
 {
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync();
 }
 
 void Chat::forceUnloadModel()
 {
     stopGenerating();
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync(/*unload*/ true);
 }
 
 void Chat::forceReloadModel()
 {
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync(/*reload*/ true);
 }
 
 void Chat::trySwitchContextOfLoadedModel()
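With this hunk the four Chat entry points differ only in which backend call and flag they use; a caller-side summary (hypothetical handler functions, not part of the commit; the per-call comments paraphrase the new loadModel/releaseModel bodies further down):

```cpp
// Hypothetical GUI-side handlers; Chat is assumed to expose the four slots above.
void onChatHidden(Chat *chat)      { chat->unloadModel(); }      // releaseModelAsync(): weights may stay cached in the shared store
void onChatShown(Chat *chat)       { chat->reloadModel(); }      // loadModelAsync(): no-op if a model is already loaded
void onEjectClicked(Chat *chat)    { chat->forceUnloadModel(); } // releaseModelAsync(/*unload*/ true): actually frees the model
void onSettingsApplied(Chat *chat) { chat->forceReloadModel(); } // loadModelAsync(/*reload*/ true): release first, then load again
```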
@@ -145,7 +145,6 @@ Q_SIGNALS:
     void modelChangeRequested(const ModelInfo &modelInfo);
     void modelInfoChanged();
     void restoringFromTextChanged();
-    void loadDefaultModelRequested();
     void loadModelRequested(const ModelInfo &modelInfo);
     void generateNameRequested();
     void modelLoadingErrorChanged();
@@ -106,7 +106,6 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_promptTokens(0)
     , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
     , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
@@ -117,8 +116,10 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
-    connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
-        Qt::QueuedConnection); // explicitly queued
+    connect<void(LlamaCppModel::*)(bool), void(LlamaCppModel::*)(bool)>(
+        this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel
+    );
+    connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel);
     connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
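Two details of these replacement connections are easy to miss. First, the verbose connect<void(LlamaCppModel::*)(bool), ...> form exists only to pick an overload: loadModel is overloaded on bool and on const ModelInfo & (see the header hunks below), so a plain member-function pointer would be ambiguous, and qOverload is the usual shorter spelling. Second, because these are self-connections on an object that lives on m_llmThread, the default Qt::AutoConnection resolves to a queued call whenever the emit comes from another thread, which is what makes loadModelAsync/releaseModelAsync safe to call from the GUI thread. A sketch of the two disambiguation spellings, as they would appear inside the constructor:

```cpp
// Sketch only; assumes the overloads loadModel(bool) and loadModel(const ModelInfo &)
// declared in the header hunks below.

// 1. Explicit template arguments, as the commit writes it:
connect<void (LlamaCppModel::*)(bool), void (LlamaCppModel::*)(bool)>(
    this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel);

// 2. Equivalent and usually easier to read (Qt 5.7+ / C++14):
connect(this, &LlamaCppModel::requestLoadModel,
        this, qOverload<bool>(&LlamaCppModel::loadModel));
```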
@@ -170,8 +171,7 @@ void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
     m_forceMetal = forceMetal;
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 #endif
@@ -181,22 +181,11 @@ void LlamaCppModel::handleDeviceChanged()
 {
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 }
 
-bool LlamaCppModel::loadDefaultModel()
-{
-    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
-    if (defaultModel.filename().isEmpty()) {
-        emit modelLoadingError(u"Could not find any model to load"_qs);
-        return false;
-    }
-    return loadModel(defaultModel);
-}
-
 void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
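Note that the standalone loadDefaultModel() definition disappears here rather than moving: its default-model fallback is folded into the new loadModel(bool reload) shown in a later hunk, and its failure path changes from a bool return to a void return that only emits modelLoadingError. Roughly (paraphrasing the body shown further down, not a literal excerpt):

```cpp
// Paraphrase of the new loadModel(bool) flow; error handling simplified,
// see the actual hunk below for the emitted modelLoadingError path.
void LlamaCppModel::loadModel(bool reload)
{
    if (reload)
        releaseModel(/*unload*/ true);   // replaces the old unloadModel(); reloadModel(); pair
    else if (isModelLoaded())
        return;                          // already loaded: nothing to do

    ModelInfo m = modelInfo();
    if (m.name().isEmpty())
        m = ModelList::globalInstance()->defaultModelInfo();  // old loadDefaultModel() fallback
    loadModel(m);                        // delegate to the ModelInfo overload
}
```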
@@ -820,13 +809,16 @@ bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const Q
     return true;
 }
 
-void LlamaCppModel::setShouldBeLoaded(bool b)
+void LlamaCppModel::loadModelAsync(bool reload)
 {
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
-#endif
-    m_shouldBeLoaded = b; // atomic
-    emit shouldBeLoadedChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit requestLoadModel(reload);
+}
+
+void LlamaCppModel::releaseModelAsync(bool unload)
+{
+    m_shouldBeLoaded = false; // atomic
+    emit requestReleaseModel(unload);
 }
 
 void LlamaCppModel::requestTrySwitchContext()
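Both wrappers are deliberately tiny: they record intent in the atomic m_shouldBeLoaded and emit a signal whose connected slot runs on the backend thread. A hedged caller-side sketch (the caller function is hypothetical; only loadModelAsync comes from the commit):

```cpp
// Hypothetical GUI-thread caller; `backend` is assumed to be an LLModel *
// whose QObject lives on its own worker thread, as wired up in the constructor.
void switchDevice(LLModel *backend)
{
    backend->loadModelAsync(/*reload*/ true); // returns immediately; the actual
                                              // release + reload runs later on the
                                              // worker thread via requestLoadModel
}
```

Because the request travels through the worker's event queue, callers observe progress through the existing modelLoadingPercentageChanged emissions rather than a return value.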
@@ -835,23 +827,43 @@ void LlamaCppModel::requestTrySwitchContext()
     emit trySwitchContextRequested(modelInfo());
 }
 
-void LlamaCppModel::handleShouldBeLoadedChanged()
+void LlamaCppModel::loadModel(bool reload)
 {
-    if (m_shouldBeLoaded)
-        reloadModel();
-    else
-        unloadModel();
+    Q_ASSERT(m_shouldBeLoaded);
+    if (m_isServer)
+        return; // server managed models directly
+
+    if (reload)
+        releaseModel(/*unload*/ true);
+    else if (isModelLoaded())
+        return; // already loaded
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
+#endif
+    ModelInfo m = modelInfo();
+    if (m.name().isEmpty()) {
+        ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
+        if (defaultModel.filename().isEmpty()) {
+            emit modelLoadingError(u"Could not find any model to load"_s);
+            return;
+        }
+        m = defaultModel;
+    }
+    loadModel(m);
 }
 
-void LlamaCppModel::unloadModel()
+void LlamaCppModel::releaseModel(bool unload)
 {
     if (!isModelLoaded() || m_isServer)
         return;
 
-    if (!m_forceUnloadModel || !m_shouldBeLoaded)
-        emit modelLoadingPercentageChanged(0.0f);
-    else
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    if (unload && m_shouldBeLoaded) {
+        // reloading the model, don't show unloaded status
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small positive value
+    } else {
+        emit modelLoadingPercentageChanged(0.0f);
+    }
 
     if (!m_markedForDeletion)
         saveState();
@@ -860,33 +872,14 @@ void LlamaCppModel::unloadModel()
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    if (m_forceUnloadModel) {
+    if (unload) {
         m_llModelInfo.resetModel(this);
-        m_forceUnloadModel = false;
     }
 
     LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
     m_pristineLoadedState = false;
 }
 
-void LlamaCppModel::reloadModel()
-{
-    if (isModelLoaded() && m_forceUnloadModel)
-        unloadModel(); // we unload first if we are forcing an unload
-
-    if (isModelLoaded() || m_isServer)
-        return;
-
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
-#endif
-    const ModelInfo m = modelInfo();
-    if (m.name().isEmpty())
-        loadDefaultModel();
-    else
-        loadModel(m);
-}
-
 void LlamaCppModel::generateName()
 {
     Q_ASSERT(isModelLoaded());
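releaseModel keeps the existing convention for modelLoadingPercentageChanged: 0.0f means "no model loaded", while std::numeric_limits<float>::min(), the smallest positive normalized float rather than a negative value, is a sentinel meaning "still notionally loaded, just being recycled", so the UI never flashes an unloaded state in the middle of a reload. Assuming the reported value is a 0-to-1 fraction (which the 0.0f endpoint suggests), a consumer might interpret it like this (hypothetical handler, not from the commit):

```cpp
#include <limits>

// Hypothetical consumer of modelLoadingPercentageChanged, illustrating the
// sentinel: 0.0f = unloaded, anything in (0.0f, 1.0f) = loading or reloading,
// 1.0f = fully loaded.
void onModelLoadingPercentageChanged(float fraction)
{
    if (fraction == 0.0f) {
        // show "no model loaded"
    } else if (fraction < 1.0f) {
        // show a progress / reloading indicator; std::numeric_limits<float>::min()
        // lands here, so a reload-in-progress never looks like an unload
    } else {
        // show "model ready"
    }
}
```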
@@ -109,9 +109,9 @@ public:
 
     void stopGenerating() override { m_stopGenerating = true; }
 
-    void setShouldBeLoaded(bool b) override;
+    void loadModelAsync(bool reload = false) override;
+    void releaseModelAsync(bool unload = false) override;
     void requestTrySwitchContext() override;
     void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
 
     void setModelInfo(const ModelInfo &info) override;
@@ -147,14 +147,14 @@ public:
 
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
-    bool loadDefaultModel() override;
     bool loadModel(const ModelInfo &modelInfo) override;
     void modelChangeRequested(const ModelInfo &modelInfo) override;
     void generateName() override;
     void processSystemPrompt() override;
 
 Q_SIGNALS:
-    void shouldBeLoadedChanged();
+    void requestLoadModel(bool reload);
+    void requestReleaseModel(bool unload);
 
 protected:
     bool isModelLoaded() const;
@@ -183,11 +183,10 @@ protected:
 
 protected Q_SLOTS:
     void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
+    void loadModel(bool reload = false);
+    void releaseModel(bool unload = false);
     void generateQuestions(qint64 elapsed);
     void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -197,11 +196,13 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
+protected:
+    ModelBackend::PromptContext m_ctx;
     // used by Server
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
+    std::atomic<bool> m_shouldBeLoaded;
 
 private:
-    ModelBackend::PromptContext m_ctx;
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -212,9 +213,7 @@ private:
     QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
-    std::atomic<bool> m_shouldBeLoaded;
     std::atomic<bool> m_restoringFromText; // status indication
     std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
     bool m_isServer;
     bool m_forceMetal;
@@ -30,9 +30,9 @@ public:
 
     virtual void stopGenerating() = 0;
 
-    virtual void setShouldBeLoaded(bool b) = 0;
+    virtual void loadModelAsync(bool reload = false) = 0;
+    virtual void releaseModelAsync(bool unload = false) = 0;
     virtual void requestTrySwitchContext() = 0;
     virtual void setForceUnloadModel(bool b) = 0;
     virtual void setMarkedForDeletion(bool b) = 0;
 
     virtual void setModelInfo(const ModelInfo &info) = 0;
@@ -45,7 +45,6 @@ public:
 
 public Q_SLOTS:
     virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
-    virtual bool loadDefaultModel() = 0;
     virtual bool loadModel(const ModelInfo &modelInfo) = 0;
     virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
    virtual void generateName() = 0;
@@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
     emit requestServerNewPromptResponsePair(actualPrompt); // blocks
 
     // load the new model if necessary
-    setShouldBeLoaded(true);
+    m_shouldBeLoaded = true;
 
     if (modelInfo.filename().isEmpty()) {
         std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;
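The server path bypasses the new async calls entirely: loadModel(bool) returns early when m_isServer is set ("server managed models directly"), so the completion handler just records its intent in the now-protected atomic flag and drives loading itself. A condensed sketch of that relationship (assuming Server derives from LlamaCppModel, which the protected members and this direct assignment suggest; the method name is hypothetical):

```cpp
// Condensed sketch, not a literal excerpt from the commit.
class Server : public LlamaCppModel {
    void prepareModelForRequest(const ModelInfo &requested)
    {
        // Flip the protected atomic directly instead of calling loadModelAsync():
        // the server manages model loading itself rather than via queued signals.
        m_shouldBeLoaded = true;
        // ... then resolve `requested` (or the default model) and load it on this
        // thread, as the surrounding completion handler does ...
    }
};
```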