Complete revamp of model loading to allow for more fine-grained control by the user over the model's loading behavior.

Signed-off-by: Adam Treat <treat.adam@gmail.com>
Author: Adam Treat
Date: 2024-02-07 09:37:59 -05:00
Committed by: AT
Parent: f2024a1f9e
Commit: d948a4f2ee
14 changed files with 506 additions and 175 deletions


@@ -62,7 +62,9 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     , m_promptResponseTokens(0)
     , m_promptTokens(0)
     , m_isRecalc(false)
-    , m_shouldBeLoaded(true)
+    , m_shouldBeLoaded(false)
+    , m_forceUnloadModel(false)
+    , m_shouldTrySwitchContext(false)
     , m_stopGenerating(false)
     , m_timer(nullptr)
     , m_isServer(isServer)
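
The two new flags join m_shouldBeLoaded as cross-thread request bits: the GUI thread sets them and the worker thread acts on them. The "// atomic" comments later in this diff suggest std::atomic<bool> declarations in chatllm.h, which is not part of this excerpt; a minimal sketch of that assumption:

#include <atomic>

// Assumed declarations (chatllm.h is not shown in this commit excerpt);
// the "// atomic" comments in the diff below motivate std::atomic<bool>.
struct ChatLLMFlags {
    std::atomic<bool> m_shouldBeLoaded{false};
    std::atomic<bool> m_forceUnloadModel{false};
    std::atomic<bool> m_shouldTrySwitchContext{false};
};

int main() {
    ChatLLMFlags f;
    f.m_shouldBeLoaded = true;          // GUI thread requests a load...
    return f.m_shouldBeLoaded ? 0 : 1;  // ...worker thread observes it
}
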
@@ -76,6 +78,8 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
     connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
     connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged,
         Qt::QueuedConnection); // explicitly queued
+    connect(this, &ChatLLM::shouldTrySwitchContextChanged, this, &ChatLLM::handleShouldTrySwitchContextChanged,
+        Qt::QueuedConnection); // explicitly queued
     connect(parent, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
     connect(&m_llmThread, &QThread::started, this, &ChatLLM::handleThreadStarted);
     connect(MySettings::globalInstance(), &MySettings::forceMetalChanged, this, &ChatLLM::handleForceMetalChanged);
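
Both connections are explicitly queued for the same reason: the setters can be called from the GUI thread, but the handlers must run on the ChatLLM worker thread's event loop. A minimal sketch of the pattern; Worker is a stand-in for ChatLLM, and the single-file layout is an assumption:

#include <QCoreApplication>
#include <QDebug>
#include <QObject>
#include <QThread>
#include <QTimer>

// Worker stands in for ChatLLM, purely for illustration.
class Worker : public QObject {
    Q_OBJECT
public:
    void request() { emit shouldTrySwitchContextChanged(); } // any thread
signals:
    void shouldTrySwitchContextChanged();
public slots:
    void handleShouldTrySwitchContextChanged() {
        // Because the connection is queued, this always runs on the
        // thread the Worker lives in, regardless of the emitting thread.
        qDebug() << "handled on" << QThread::currentThread();
    }
};

int main(int argc, char *argv[]) {
    QCoreApplication app(argc, argv);
    QThread llmThread;
    llmThread.start();
    Worker w;
    w.moveToThread(&llmThread);
    QObject::connect(&w, &Worker::shouldTrySwitchContextChanged,
                     &w, &Worker::handleShouldTrySwitchContextChanged,
                     Qt::QueuedConnection); // explicitly queued
    w.request(); // emitted from the main thread, handled on llmThread
    QTimer::singleShot(100, &app, &QCoreApplication::quit);
    int rc = app.exec();
    llmThread.quit();
    llmThread.wait();
    return rc;
}

#include "main.moc" // assumes this file is main.cpp, built with moc
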
@@ -143,6 +147,54 @@ bool ChatLLM::loadDefaultModel()
     return loadModel(defaultModel);
 }
 
+bool ChatLLM::trySwitchContextOfLoadedModel(const ModelInfo &modelInfo)
+{
+    // We're trying to see if the store already has the model we wish to use fully loaded, and
+    // if so we just acquire it from the store, switch the context, and return true. If the
+    // store doesn't have it, or in any other case, we return false.
+
+    // If we're already loaded, or are a server, or are reloading to change the variant/device,
+    // or the modelInfo is empty, then this should fail
+    if (isModelLoaded() || m_isServer || m_reloadingToChangeVariant || modelInfo.name().isEmpty()) {
+        m_shouldTrySwitchContext = false;
+        emit trySwitchContextOfLoadedModelCompleted(false);
+        return false;
+    }
+
+    QString filePath = modelInfo.dirpath + modelInfo.filename();
+    QFileInfo fileInfo(filePath);
+
+    m_llModelInfo = LLModelStore::globalInstance()->acquireModel();
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "acquired model from store" << m_llmThread.objectName() << m_llModelInfo.model;
+#endif
+
+    // If the store gave us no already loaded model, or the wrong type of model, then give it
+    // back to the store and fail
+    if (!m_llModelInfo.model || m_llModelInfo.fileInfo != fileInfo) {
+        LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
+        m_llModelInfo = LLModelInfo();
+        m_shouldTrySwitchContext = false;
+        emit trySwitchContextOfLoadedModelCompleted(false);
+        return false;
+    }
+
+#if defined(DEBUG_MODEL_LOADING)
+    qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
+#endif
+
+    // We should be loaded and now we are
+    m_shouldBeLoaded = true;
+    m_shouldTrySwitchContext = false;
+
+    // Restore, signal and process
+    restoreState();
+    emit modelLoadingPercentageChanged(1.0f);
+    emit trySwitchContextOfLoadedModelCompleted(true);
+    processSystemPrompt();
+    return true;
+}
+
 bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 {
     // This is a complicated method because N different possible threads are interested in the outcome
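
Nothing in the hunk above shows who sets m_shouldTrySwitchContext; presumably the chat-switching code flags the newly current chat's ChatLLM. A toy, single-threaded model of that handoff; every name here is an illustrative stand-in, not the real call site:

#include <atomic>
#include <iostream>

// Stand-in for ChatLLM, modeling only the flag handoff.
struct ToyChatLLM {
    std::atomic<bool> shouldTrySwitchContext{false};

    void setShouldTrySwitchContext(bool b) {
        shouldTrySwitchContext = b; // atomic
        // In the real code this emits shouldTrySwitchContextChanged() and
        // the queued slot runs on the worker thread; here we call directly.
        handleShouldTrySwitchContextChanged();
    }

    void handleShouldTrySwitchContextChanged() {
        if (shouldTrySwitchContext)
            trySwitchContextOfLoadedModel();
    }

    void trySwitchContextOfLoadedModel() {
        shouldTrySwitchContext = false;
        std::cout << "try to adopt the store's already-loaded model\n";
    }
};

int main() {
    ToyChatLLM target;
    // Hypothetical trigger: the user switched chats, so ask the newly
    // current chat's ChatLLM to adopt the already-loaded model.
    target.setShouldTrySwitchContext(true);
}
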
@@ -170,7 +222,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 #endif
         delete m_llModelInfo.model;
         m_llModelInfo.model = nullptr;
-        emit isModelLoadedChanged(false);
+        emit modelLoadingPercentageChanged(std::numeric_limits<float>::min());
     } else if (!m_isServer) {
         // This is a blocking call that tries to retrieve the model we need from the model store.
         // If it succeeds, then we just have to restore state. If the store has never had a model
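
The "blocking call" the comment refers to is LLModelStore::acquireModel(), whose implementation is not part of this excerpt. Below is a plausible single-slot store sketched with a condition variable, under the assumption that the store holds at most one model at a time:

#include <condition_variable>
#include <mutex>
#include <optional>
#include <utility>

// Sketch of a single-slot model store (assumed design, inferred from the
// acquire/release calls in this diff): acquireModel() blocks until some
// thread has released a model into the slot.
template <typename Info>
class SingleSlotStore {
    std::mutex m_mutex;
    std::condition_variable m_cond;
    std::optional<Info> m_slot;
public:
    Info acquireModel() {
        std::unique_lock<std::mutex> lock(m_mutex);
        m_cond.wait(lock, [this] { return m_slot.has_value(); });
        Info info = std::move(*m_slot);
        m_slot.reset();
        return info;
    }
    void releaseModel(Info info) {
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_slot = std::move(info);
        }
        m_cond.notify_one();
    }
};

int main() {
    SingleSlotStore<int> store;         // int stands in for LLModelInfo
    store.releaseModel(42);             // one ChatLLM hands its model back
    int adopted = store.acquireModel(); // another ChatLLM adopts it
    return adopted == 42 ? 0 : 1;
}
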
@@ -188,7 +240,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
 #endif
             LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
             m_llModelInfo = LLModelInfo();
-            emit isModelLoadedChanged(false);
+            emit modelLoadingPercentageChanged(0.0f);
             return false;
         }
@@ -198,7 +250,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
qDebug() << "store had our model" << m_llmThread.objectName() << m_llModelInfo.model;
#endif
restoreState();
emit isModelLoadedChanged(true);
emit modelLoadingPercentageChanged(1.0f);
setModelInfo(modelInfo);
Q_ASSERT(!m_modelInfo.filename().isEmpty());
if (m_modelInfo.filename().isEmpty())
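
Taken together, these hunks replace the boolean isModelLoadedChanged signal with modelLoadingPercentageChanged(float), which packs several states into one value. The mapping below is inferred from the values emitted in this diff, not stated anywhere in it:

#include <limits>

// Inferred semantics of the new progress signal (an assumption based on
// the values emitted in this commit):
//   0.0f                              -> fully unloaded
//   std::numeric_limits<float>::min() -> unloaded, but a reload is pending
//   (0.0f, 1.0f)                      -> loading, fraction complete
//   1.0f                              -> fully loaded
enum class LoadState { Unloaded, ReloadPending, Loading, Loaded };

LoadState classify(float p) {
    if (p <= 0.0f)
        return LoadState::Unloaded;
    if (p == std::numeric_limits<float>::min()) // checked before Loading,
        return LoadState::ReloadPending;        // since min() is in (0, 1)
    if (p < 1.0f)
        return LoadState::Loading;
    return LoadState::Loaded;
}
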
@@ -261,6 +313,12 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             m_llModelInfo.model = LLModel::Implementation::construct(filePath.toStdString(), buildVariant, n_ctx);
 
             if (m_llModelInfo.model) {
+                m_llModelInfo.model->setProgressCallback([this](float progress) -> bool {
+                    emit modelLoadingPercentageChanged(progress);
+                    return m_shouldBeLoaded;
+                });
+
                 // Update the settings that a model is being loaded and update the device list
                 MySettings::globalInstance()->setAttemptModelLoad(filePath);
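
The callback registered above does double duty: it forwards progress to the UI and, by returning m_shouldBeLoaded, lets the backend abort a load that the user has since cancelled. A simplified, self-contained model of that contract; the chunked loop is an assumption about the backend, not its actual code:

#include <atomic>
#include <functional>
#include <iostream>

std::atomic<bool> shouldBeLoaded{true};

// Assumed backend behavior: load in chunks, report progress after each
// chunk, and stop early if the callback returns false.
bool loadWithProgress(const std::function<bool(float)> &progress) {
    for (int chunk = 1; chunk <= 100; ++chunk) {
        // ... load the next slice of weights ...
        if (!progress(chunk / 100.0f))
            return false; // the caller asked us to stop mid-load
    }
    return true;
}

int main() {
    bool ok = loadWithProgress([](float p) -> bool {
        std::cout << "progress " << p << '\n'; // emit to the UI in real code
        return shouldBeLoaded.load();          // false aborts the load
    });
    std::cout << (ok ? "loaded" : "aborted") << '\n';
}
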
@@ -354,7 +412,7 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
qDebug() << "modelLoadedChanged" << m_llmThread.objectName();
fflush(stdout);
#endif
emit isModelLoadedChanged(isModelLoaded());
emit modelLoadingPercentageChanged(isModelLoaded() ? 1.0f : 0.0f);
static bool isFirstLoad = true;
if (isFirstLoad) {
@@ -456,6 +514,7 @@ void ChatLLM::setModelInfo(const ModelInfo &modelInfo)
 void ChatLLM::modelChangeRequested(const ModelInfo &modelInfo)
 {
+    m_shouldBeLoaded = true;
     loadModel(modelInfo);
 }
@@ -598,6 +657,12 @@ void ChatLLM::setShouldBeLoaded(bool b)
     emit shouldBeLoadedChanged();
 }
 
+void ChatLLM::setShouldTrySwitchContext(bool b)
+{
+    m_shouldTrySwitchContext = b; // atomic
+    emit shouldTrySwitchContextChanged();
+}
+
 void ChatLLM::handleShouldBeLoadedChanged()
 {
     if (m_shouldBeLoaded)
@@ -606,10 +671,10 @@ void ChatLLM::handleShouldBeLoadedChanged()
         unloadModel();
 }
 
-void ChatLLM::forceUnloadModel()
+void ChatLLM::handleShouldTrySwitchContextChanged()
 {
-    m_shouldBeLoaded = false; // atomic
-    unloadModel();
+    if (m_shouldTrySwitchContext)
+        trySwitchContextOfLoadedModel(modelInfo());
 }
 
 void ChatLLM::unloadModel()
@@ -617,17 +682,27 @@ void ChatLLM::unloadModel()
     if (!isModelLoaded() || m_isServer)
         return;
 
+    emit modelLoadingPercentageChanged(0.0f);
     saveState();
 #if defined(DEBUG_MODEL_LOADING)
     qDebug() << "unloadModel" << m_llmThread.objectName() << m_llModelInfo.model;
 #endif
+
+    if (m_forceUnloadModel) {
+        delete m_llModelInfo.model;
+        m_llModelInfo.model = nullptr;
+        m_forceUnloadModel = false;
+    }
+
     LLModelStore::globalInstance()->releaseModel(m_llModelInfo);
     m_llModelInfo = LLModelInfo();
-    emit isModelLoadedChanged(false);
 }
 
 void ChatLLM::reloadModel()
 {
+    if (isModelLoaded() && m_forceUnloadModel)
+        unloadModel(); // we unload first if we are forcing an unload
+
     if (isModelLoaded() || m_isServer)
         return;
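
Together, unloadModel() and reloadModel() now distinguish a soft unload (the model is handed back to the store, still resident in memory) from a forced one (the weights are freed first). A toy model of the decision flow, with the stand-in behavior noted in comments:

#include <atomic>
#include <iostream>

// Toy model of the new unload/reload flow; all types and behavior here
// are illustrative stand-ins for ChatLLM and LLModelStore.
struct ToyChatLLM {
    std::atomic<bool> loaded{true};
    std::atomic<bool> forceUnload{false};

    void deleteModel()    { std::cout << "weights freed\n"; }
    void releaseToStore() { loaded = false; /* store keeps the slot */ }

    void unloadModel() {
        if (!loaded)
            return;
        if (forceUnload) {      // forced: actually destroy the weights
            deleteModel();
            forceUnload = false;
        }
        releaseToStore();       // soft: the store keeps the model warm
    }

    void reloadModel() {
        if (loaded && forceUnload)
            unloadModel();      // forcing: unload first, then reload below
        if (loaded)
            return;             // still loaded and not forcing: nothing to do
        std::cout << "acquire from store or construct from disk\n";
    }
};

int main() {
    ToyChatLLM llm;
    llm.forceUnload = true; // e.g. the user changed the device setting
    llm.reloadModel();      // frees the old weights, then loads fresh
}
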