Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-06-25 23:13:06 +00:00)
commit 8fd9f01578 (parent 05bd6042b6)

replace setShouldBeLoaded with loadModelAsync/releaseModelAsync

Signed-off-by: Jared Van Bortel <jared@nomic.ai>
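In summary (grounded entirely in the hunks below): the single atomic setter setShouldBeLoaded(bool) is split into two explicit asynchronous entry points, loadModelAsync(bool reload = false) and releaseModelAsync(bool unload = false), and the setForceUnloadModel() side channel becomes a plain parameter. Call sites in Chat translate as follows:

    m_llmodel->loadModelAsync();                    // was setShouldBeLoaded(true)
    m_llmodel->releaseModelAsync();                 // was setShouldBeLoaded(false)
    m_llmodel->loadModelAsync(/*reload*/ true);     // was setForceUnloadModel(true) + setShouldBeLoaded(true)
    m_llmodel->releaseModelAsync(/*unload*/ true);  // was setForceUnloadModel(true) + setShouldBeLoaded(false)

Both calls merely update the atomic m_shouldBeLoaded and emit a signal; the actual loadModel()/releaseModel() slots run on the LLM thread.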
@@ -74,7 +74,6 @@ void Chat::connectLLM()
     connect(this, &Chat::promptRequested, m_llmodel, &LLModel::prompt, Qt::QueuedConnection);
     connect(this, &Chat::modelChangeRequested, m_llmodel, &LLModel::modelChangeRequested, Qt::QueuedConnection);
-    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &LLModel::loadDefaultModel, Qt::QueuedConnection);
     connect(this, &Chat::loadModelRequested, m_llmodel, &LLModel::loadModel, Qt::QueuedConnection);
     connect(this, &Chat::generateNameRequested, m_llmodel, &LLModel::generateName, Qt::QueuedConnection);
     connect(this, &Chat::regenerateResponseRequested, m_llmodel, &LLModel::regenerateResponse, Qt::QueuedConnection);
@@ -277,25 +276,23 @@ void Chat::markForDeletion()
 void Chat::unloadModel()
 {
     stopGenerating();
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync();
 }
 
 void Chat::reloadModel()
 {
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync();
 }
 
 void Chat::forceUnloadModel()
 {
     stopGenerating();
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(false);
+    m_llmodel->releaseModelAsync(/*unload*/ true);
 }
 
 void Chat::forceReloadModel()
 {
-    m_llmodel->setForceUnloadModel(true);
-    m_llmodel->setShouldBeLoaded(true);
+    m_llmodel->loadModelAsync(/*reload*/ true);
 }
 
 void Chat::trySwitchContextOfLoadedModel()
@@ -145,7 +145,6 @@ Q_SIGNALS:
     void modelChangeRequested(const ModelInfo &modelInfo);
     void modelInfoChanged();
     void restoringFromTextChanged();
-    void loadDefaultModelRequested();
     void loadModelRequested(const ModelInfo &modelInfo);
     void generateNameRequested();
     void modelLoadingErrorChanged();
@@ -106,7 +106,6 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_promptTokens(0)
     , m_restoringFromText(false)
     , m_shouldBeLoaded(false)
-    , m_forceUnloadModel(false)
     , m_markedForDeletion(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
@@ -117,8 +116,10 @@ LlamaCppModel::LlamaCppModel(Chat *parent, bool isServer)
     , m_restoreStateFromText(false)
 {
     moveToThread(&m_llmThread);
-    connect(this, &LlamaCppModel::shouldBeLoadedChanged, this, &LlamaCppModel::handleShouldBeLoadedChanged,
-        Qt::QueuedConnection); // explicitly queued
+    connect<void(LlamaCppModel::*)(bool), void(LlamaCppModel::*)(bool)>(
+        this, &LlamaCppModel::requestLoadModel, this, &LlamaCppModel::loadModel
+    );
+    connect(this, &LlamaCppModel::requestReleaseModel, this, &LlamaCppModel::releaseModel);
     connect(this, &LlamaCppModel::trySwitchContextRequested, this, &LlamaCppModel::trySwitchContextOfLoadedModel,
         Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &LlamaCppModel::handleChatIdChanged);
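A note on the constructor hunk above: requestLoadModel(bool) is connected to loadModel(bool) with explicit template arguments because loadModel is overloaded in this class (there is also the public slot loadModel(const ModelInfo &)), so the plain pointer-to-member syntax would be ambiguous. A minimal sketch of the same disambiguation, with a hypothetical Worker class (not part of the commit); qOverload is the equivalent, more common spelling since Qt 5.7:

    #include <QObject>

    class Worker : public QObject
    {
        Q_OBJECT
    public:
        Worker()
        {
            // explicit template arguments pick the bool overload of load()
            connect<void(Worker::*)(bool), void(Worker::*)(bool)>(
                this, &Worker::requestLoad, this, &Worker::load);
            // equivalent disambiguation (shown for comparison; one connect suffices):
            // connect(this, &Worker::requestLoad, this, qOverload<bool>(&Worker::load));
        }
    Q_SIGNALS:
        void requestLoad(bool reload);
    public Q_SLOTS:
        void load(bool reload) { Q_UNUSED(reload) }        // selected overload
        void load(const QString &name) { Q_UNUSED(name) }  // competing overload
    };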
@@ -170,8 +171,7 @@ void LlamaCppModel::handleForceMetalChanged(bool forceMetal)
     m_forceMetal = forceMetal;
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 #endif
@@ -181,22 +181,11 @@ void LlamaCppModel::handleDeviceChanged()
 {
     if (isModelLoaded() && m_shouldBeLoaded) {
         m_reloadingToChangeVariant = true;
-        unloadModel();
-        reloadModel();
+        loadModel(/*reload*/ true);
         m_reloadingToChangeVariant = false;
     }
 }
 
-bool LlamaCppModel::loadDefaultModel()
-{
-    ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
-    if (defaultModel.filename().isEmpty()) {
-        emit modelLoadingError(u"Could not find any model to load"_qs);
-        return false;
-    }
-    return loadModel(defaultModel);
-}
-
 void LlamaCppModel::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
 {
     // We're trying to see if the store already has the model fully loaded that we wish to use
@@ -820,13 +809,16 @@ bool LlamaCppModel::promptInternal(const QList<QString> &collectionList, const Q
     return true;
 }
 
-void LlamaCppModel::setShouldBeLoaded(bool b)
+void LlamaCppModel::loadModelAsync(bool reload)
 {
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "setShouldBeLoaded" << m_llmThread.objectName() << b << m_llModelInfo.model.get();
-#endif
-    m_shouldBeLoaded = b; // atomic
-    emit shouldBeLoadedChanged();
+    m_shouldBeLoaded = true; // atomic
+    emit requestLoadModel(reload);
+}
+
+void LlamaCppModel::releaseModelAsync(bool unload)
+{
+    m_shouldBeLoaded = false; // atomic
+    emit requestReleaseModel(unload);
 }
 
 void LlamaCppModel::requestTrySwitchContext()
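The two new entry points above are deliberately tiny: they flip the atomic m_shouldBeLoaded synchronously, so code on any thread (for example the prompt loop) observes the intent immediately, while the expensive work is deferred to the LLM thread through the requestLoadModel/requestReleaseModel signals connected in the constructor. A self-contained sketch of the pattern, with a hypothetical ModelHost class (names and members are illustrative, not from the commit):

    #include <QObject>
    #include <QThread>
    #include <atomic>

    class ModelHost : public QObject
    {
        Q_OBJECT
    public:
        ModelHost()
        {
            moveToThread(&m_thread);
            // cross-thread at emit time, so this connection is automatically queued
            connect(this, &ModelHost::requestLoad, this, &ModelHost::doLoad);
            m_thread.start();
        }

        // callable from any thread, returns immediately
        void loadAsync(bool reload = false)
        {
            m_shouldBeLoaded = true;   // atomic store, visible to the worker at once
            emit requestLoad(reload);  // doLoad() runs later, on m_thread
        }

    Q_SIGNALS:
        void requestLoad(bool reload);

    private Q_SLOTS:
        void doLoad(bool reload) { Q_UNUSED(reload) /* expensive model load runs here */ }

    private:
        QThread m_thread;
        std::atomic<bool> m_shouldBeLoaded{false};
    };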
@@ -835,23 +827,43 @@ void LlamaCppModel::requestTrySwitchContext()
     emit trySwitchContextRequested(modelInfo());
 }
 
-void LlamaCppModel::handleShouldBeLoadedChanged()
+void LlamaCppModel::loadModel(bool reload)
 {
-    if (m_shouldBeLoaded)
-        reloadModel();
-    else
-        unloadModel();
+    Q_ASSERT(m_shouldBeLoaded);
+    if (m_isServer)
+        return; // server managed models directly
+
+    if (reload)
+        releaseModel(/*unload*/ true);
+    else if (isModelLoaded())
+        return; // already loaded
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "loadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
+#endif
+    ModelInfo m = modelInfo();
+    if (m.name().isEmpty()) {
+        ModelInfo defaultModel = ModelList::globalInstance()->defaultModelInfo();
+        if (defaultModel.filename().isEmpty()) {
+            emit modelLoadingError(u"Could not find any model to load"_s);
+            return;
+        }
+        m = defaultModel;
+    }
+    loadModel(m);
 }
 
-void LlamaCppModel::unloadModel()
+void LlamaCppModel::releaseModel(bool unload)
 {
     if (!isModelLoaded() || m_isServer)
         return;
 
-    if (!m_forceUnloadModel || !m_shouldBeLoaded)
+    if (unload && m_shouldBeLoaded) {
+        // reloading the model, don't show unloaded status
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small positive value
+    } else {
         emit modelLoadingPercentageChanged(0.0f);
-    else
-        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min()); // small non-zero positive value
+    }
 
     if (!m_markedForDeletion)
         saveState();
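One subtlety in releaseModel() above: the UI apparently treats a loading percentage of exactly 0.0f as "unloaded", so when the release is only the first half of a reload (unload && m_shouldBeLoaded) the code reports the smallest positive float to keep a loading state visible. As a sketch of just that decision (the 0 / (0,1) / 1 semantics are inferred from the diff's comments, not stated in the commit):

    #include <limits>

    // inferred: 0.0f -> unloaded, (0.0f, 1.0f) -> loading, 1.0f -> fully loaded
    inline float releasePercentage(bool reloadingSoon)
    {
        // tiny positive sentinel: "not unloaded, a load follows immediately"
        return reloadingSoon ? std::numeric_limits<float>::min() : 0.0f;
    }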
@@ -860,33 +872,14 @@ void LlamaCppModel::unloadModel()
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
 #endif
 
-    if (m_forceUnloadModel) {
+    if (unload) {
         m_llModelInfo.resetModel(this);
-        m_forceUnloadModel = false;
     }
 
     LLModelStore::globalInstance()->releaseModel(std::move(m_llModelInfo));
     m_pristineLoadedState = false;
 }
 
-void LlamaCppModel::reloadModel()
-{
-    if (isModelLoaded() && m_forceUnloadModel)
-        unloadModel(); // we unload first if we are forcing an unload
-
-    if (isModelLoaded() || m_isServer)
-        return;
-
-#if defined(DEBUG_MODEL_LOADING)
-    qDebug() << "reloadModel" << m_llmThread.objectName() << m_llModelInfo.model.get();
-#endif
-    const ModelInfo m = modelInfo();
-    if (m.name().isEmpty())
-        loadDefaultModel();
-    else
-        loadModel(m);
-}
-
 void LlamaCppModel::generateName()
 {
     Q_ASSERT(isModelLoaded());
@@ -109,9 +109,9 @@ public:
 
     void stopGenerating() override { m_stopGenerating = true; }
 
-    void setShouldBeLoaded(bool b) override;
+    void loadModelAsync(bool reload = false) override;
+    void releaseModelAsync(bool unload = false) override;
     void requestTrySwitchContext() override;
-    void setForceUnloadModel(bool b) override { m_forceUnloadModel = b; }
     void setMarkedForDeletion(bool b) override { m_markedForDeletion = b; }
 
     void setModelInfo(const ModelInfo &info) override;
@@ -147,14 +147,14 @@ public:
 
 public Q_SLOTS:
     bool prompt(const QList<QString> &collectionList, const QString &prompt) override;
-    bool loadDefaultModel() override;
     bool loadModel(const ModelInfo &modelInfo) override;
     void modelChangeRequested(const ModelInfo &modelInfo) override;
     void generateName() override;
     void processSystemPrompt() override;
 
 Q_SIGNALS:
-    void shouldBeLoadedChanged();
+    void requestLoadModel(bool reload);
+    void requestReleaseModel(bool unload);
 
 protected:
     bool isModelLoaded() const;
@@ -183,11 +183,10 @@ protected:
 
 protected Q_SLOTS:
     void trySwitchContextOfLoadedModel(const ModelInfo &modelInfo);
-    void unloadModel();
-    void reloadModel();
+    void loadModel(bool reload = false);
+    void releaseModel(bool unload = false);
     void generateQuestions(qint64 elapsed);
     void handleChatIdChanged(const QString &id);
-    void handleShouldBeLoadedChanged();
     void handleThreadStarted();
     void handleForceMetalChanged(bool forceMetal);
     void handleDeviceChanged();
@@ -197,11 +196,13 @@ private:
     bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps);
 
 protected:
-    ModelBackend::PromptContext m_ctx;
+    // used by Server
     quint32 m_promptTokens;
     quint32 m_promptResponseTokens;
+    std::atomic<bool> m_shouldBeLoaded;
 
 private:
+    ModelBackend::PromptContext m_ctx;
     std::string m_response;
     std::string m_nameResponse;
     QString m_questionResponse;
@@ -212,9 +213,7 @@ private:
     QByteArray m_state;
     QThread m_llmThread;
     std::atomic<bool> m_stopGenerating;
-    std::atomic<bool> m_shouldBeLoaded;
     std::atomic<bool> m_restoringFromText; // status indication
-    std::atomic<bool> m_forceUnloadModel;
     std::atomic<bool> m_markedForDeletion;
     bool m_isServer;
     bool m_forceMetal;
@@ -30,9 +30,9 @@ public:
 
     virtual void stopGenerating() = 0;
 
-    virtual void setShouldBeLoaded(bool b) = 0;
+    virtual void loadModelAsync(bool reload = false) = 0;
+    virtual void releaseModelAsync(bool unload = false) = 0;
     virtual void requestTrySwitchContext() = 0;
-    virtual void setForceUnloadModel(bool b) = 0;
     virtual void setMarkedForDeletion(bool b) = 0;
 
     virtual void setModelInfo(const ModelInfo &info) = 0;
@@ -45,7 +45,6 @@ public:
 
 public Q_SLOTS:
     virtual bool prompt(const QList<QString> &collectionList, const QString &prompt) = 0;
-    virtual bool loadDefaultModel() = 0;
     virtual bool loadModel(const ModelInfo &modelInfo) = 0;
    virtual void modelChangeRequested(const ModelInfo &modelInfo) = 0;
     virtual void generateName() = 0;
@@ -352,7 +352,7 @@ QHttpServerResponse Server::handleCompletionRequest(const QHttpServerRequest &re
         emit requestServerNewPromptResponsePair(actualPrompt); // blocks
 
         // load the new model if necessary
-        setShouldBeLoaded(true);
+        m_shouldBeLoaded = true;
 
         if (modelInfo.filename().isEmpty()) {
             std::cerr << "ERROR: couldn't load default model " << modelRequested.toStdString() << std::endl;