Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-09-08 11:58:53 +00:00
Complete revamp of model loading to allow for more discrete control by the user of the model's loading behavior.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
@@ -62,7 +62,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
    , m_promptResponseTokens(0)
    , m_promptTokens(0)
    , m_isRecalc(false)
    , m_shouldBeLoaded(true)
    , m_shouldBeLoaded(false)
    , m_forceUnloadModel(false)
    , m_shouldTrySwitchContext(false)
    , m_stopGenerating(false)
    , m_timer(nullptr)
    , m_isServer(isServer)
@@ -76,6 +78,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
    connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
        Qt::QueuedConnection); // explicitly queued
    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
        Qt::QueuedConnection); // explicitly queued
    connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
    connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
    connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
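The handleShouldBeLoadedChanged and handleShouldTrySwitchContextChanged connections above are explicitly queued so that the handlers always run on the event loop of the thread the ChatLLM object lives in, even when the corresponding flag is flipped and the signal emitted from the GUI thread. A minimal sketch of that pattern, using a generic worker object (the Worker class and its names are illustrative, not part of this diff):

    #include <QObject>
    #include <atomic>

    // Sketch: a worker whose slot must run on the worker's own thread.
    // Qt::QueuedConnection posts the invocation to the receiver's event
    // loop instead of calling the slot directly in the emitter's thread.
    class Worker : public QObject {
        Q_OBJECT
    public:
        Worker() {
            connect(this, &Worker::flagChanged,
                    this, &Worker::handleFlagChanged,
                    Qt::QueuedConnection); // explicitly queued
        }
        void setFlag(bool b) {      // safe to call from any thread
            m_flag = b;             // atomic write
            emit flagChanged();     // handler still runs on the worker's thread
        }
    signals:
        void flagChanged();
    private slots:
        void handleFlagChanged() { /* reacts to m_flag on the worker's thread */ }
    private:
        std::atomic<bool> m_flag { false };
    };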
@@ -143,6 +147,54 @@ bool ChatLLM::loadDefaultModel()
    return loadModel(defaultModel);
}

bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
{
    // We're trying to see if the store already has the model fully loaded that we wish to use
    // and if so we just acquire it from the store and switch the context and return true. If the
    // store doesn't have it or we're already loaded or in any other case just return false.

    // If we're already loaded or a server or we're reloading to change the variant/device or the
    // modelInfo is empty, then this should fail
    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
        m_shouldTrySwitchContext = false;
        emit trySwitchContextOfLoadedModelCompleted(false);
        return false;
    }

    QString filePath = modelInfo.dirpath + modelInfo.filename();
    QFileInfo fileInfo(filePath);

    m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
#endif

    // The store gave us no already loaded model, the wrong type of model, then give it back to the
    // store and fail
    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
        m_llModelInfo = LLModelInfo();
        m_shouldTrySwitchContext = false;
        emit trySwitchContextOfLoadedModelCompleted(false);
        return false;
    }

#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif

    // We should be loaded and now we are
    m_shouldBeLoaded = true;
    m_shouldTrySwitchContext = false;

    // Restore, signal and process
    restoreState();
    emit modelLoadingPercentageChanged(1.0f);
    emit trySwitchContextOfLoadedModelCompleted(true);
    processSystemPrompt();
    return true;
}

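The new trySwitchContextOfLoadedModel() above is the cheap path added by this commit: it only succeeds when the shared LLModelStore already holds a fully loaded model whose file matches the requested ModelInfo; in every other case it returns whatever it acquired to the store and reports failure so the normal loadModel() path can run instead. A hedged sketch of how a caller might drive it through the new flag and completion signal (the free function below is illustrative and not part of this diff):

    #include "chatllm.h"

    // Illustrative only: ask ChatLLM to reuse an already loaded model from
    // the store; if that fails, fall back to the normal full-load path.
    void requestCheapSwitch(ChatLLM *llm) {
        QObject::connect(llm, &ChatLLM::trySwitchContextOfLoadedModelCompleted,
                         llm, [](bool success) {
            if (!success) {
                // fall back to a full load, e.g. by invoking the existing
                // modelChangeRequested() slot on the ChatLLM thread
            }
        });
        llm->setShouldTrySwitchContext(true); // queued; the switch itself runs on the LLM thread
    }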
bool ChatLLM::loadModel(const ModelInfo &modelInfo)
{
    // This is a complicated method because N different possible threads are interested in the outcome
@@ -170,7 +222,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
#endif
            delete m_llModelInfo.model;
            m_llModelInfo.model = nullptr;
            emit isModelLoadedChanged(false);
            emit modelLoadingPercentageChanged(std::numeric_limits<float>::min());
        } else if (!m_isServer) {
            // This is a blocking call that tries to retrieve the model we need from the model store.
            // If it succeeds, then we just have to restore state. If the store has never had a model
@@ -188,7 +240,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
#endif
            LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
            m_llModelInfo = LLModelInfo();
            emit isModelLoadedChanged(false);
            emit modelLoadingPercentageChanged(0.0f);
            return false;
        }

@@ -198,7 +250,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
            qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
            restoreState();
            emit isModelLoadedChanged(true);
            emit modelLoadingPercentageChanged(1.0f);
            setModelInfo(modelInfo);
            Q_ASSERT(!m_modelInfo.filename().isEmpty());
            if (m_modelInfo.filename().isEmpty())
@@ -261,6 +313,12 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
            m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);

            if (m_llModelInfo.model) {

                m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
                    emit modelLoadingPercentageChanged(progress);
                    return m_shouldBeLoaded;
                });

                // Update the settings that a model is being loaded and update the device list
                MySettings::globalInstance()->setAttemptModelLoad(filePath);

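The progress callback registered above does double duty: it forwards load progress to the UI via modelLoadingPercentageChanged and, by returning m_shouldBeLoaded, lets the backend abort a load that is still in flight as soon as the flag is cleared. A minimal sketch of a loader honoring such a callback (the loop and names below are assumptions for illustration, not the actual LLModel implementation):

    #include <functional>

    // Sketch only: a loader that reports progress and stops early when the
    // callback returns false, i.e. when the caller no longer wants the model.
    bool loadWithProgress(const std::function<bool(float)> &progressCallback) {
        const int steps = 100;
        for (int i = 1; i <= steps; ++i) {
            // ... load the next chunk of the model here ...
            if (!progressCallback(float(i) / steps))
                return false; // cancelled mid-load; abort cleanly
        }
        return true;
    }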
@@ -354,7 +412,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
    qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
    fflush(stdout);
#endif
    emit isModelLoadedChanged(isModelLoaded());
    emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);

    static bool isFirstLoad = true;
    if (isFirstLoad) {
@@ -456,6 +514,7 @@ void ChatLLM::setModelInfo(const ModelInfo &modelInfo)

void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
{
    m_shouldBeLoaded = true;
    loadModel(modelInfo);
}

@@ -598,6 +657,12 @@ void ChatLLM::setShouldBeLoaded(bool b)
    emit shouldBeLoadedChanged();
}

void ChatLLM::setShouldTrySwitchContext(bool b)
{
    m_shouldTrySwitchContext = b; // atomic
    emit shouldTrySwitchContextChanged();
}

void ChatLLM::handleShouldBeLoadedChanged()
{
    if (m_shouldBeLoaded)
@@ -606,10 +671,10 @@ void ChatLLM::handleShouldBeLoadedChanged()
        unloadModel();
}

void ChatLLM::forceUnloadModel()
void ChatLLM::handleShouldTrySwitchContextChanged()
{
    m_shouldBeLoaded = false; // atomic
    unloadModel();
    if (m_shouldTrySwitchContext)
        trySwitchContextOfLoadedModel(modelInfo());
}

void ChatLLM::unloadModel()
@@ -617,17 +682,27 @@ void ChatLLM::unloadModel()
    if (!isModelLoaded() || m_isServer)
        return;

    emit modelLoadingPercentageChanged(0.0f);
    saveState();
#if defined(DEBUG_MODEL_LOADING)
    qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
#endif

    if (m_forceUnloadModel) {
        delete m_llModelInfo.model;
        m_llModelInfo.model = nullptr;
        m_forceUnloadModel = false;
    }

    LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
    m_llModelInfo = LLModelInfo();
    emit isModelLoadedChanged(false);
}

void ChatLLM::reloadModel()
{
    if (isModelLoaded() && m_forceUnloadModel)
        unloadModel(); // we unload first if we are forcing an unload

    if (isModelLoaded() || m_isServer)
        return;

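The reworked unloadModel() above distinguishes a soft unload, where the still loaded model is handed back to the shared LLModelStore so another chat can reuse it, from a forced unload that deletes the model outright before releasing the now empty slot; reloadModel() honors the same flag by unloading first when a force is pending. A hedged sketch of the forced path from a caller's side (the setter name below is a hypothetical, since this diff does not show how m_forceUnloadModel gets set):

    // Illustrative only: force a full teardown and reload, e.g. after a
    // settings change that requires constructing the model from scratch.
    void forceReload(ChatLLM *llm) {
        llm->setForceUnloadModel(true); // hypothetical setter for m_forceUnloadModel
        llm->reloadModel();             // unloads (deleting the model) and loads again
    }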