Much better memory mgmt for multi-threaded model loading/unloading.

This commit is contained in:
Adam Treat
2023-05-13 19:05:35 -04:00
committed by AT
parent 2989b74d43
commit ddc24acf33
6 changed files with 243 additions and 74 deletions

View File

@@ -12,6 +12,7 @@ Chat::Chat(QObject *parent)
, m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new ChatLLM(this))
, m_isServer(false)
, m_shouldDeleteLater(false)
{
connectLLM();
}
@@ -25,6 +26,7 @@ Chat::Chat(bool isServer, QObject *parent)
, m_creationDate(QDateTime::currentSecsSinceEpoch())
, m_llmodel(new Server(this))
, m_isServer(true)
, m_shouldDeleteLater(false)
{
connectLLM();
}
@@ -43,6 +45,7 @@ void Chat::connectLLM()
// Should be in different threads
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
@@ -55,8 +58,6 @@ void Chat::connectLLM()
connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
// The following are blocking operations and will block the gui thread, therefore must be fast
@@ -122,6 +123,12 @@ void Chat::handleResponseChanged()
emit responseChanged();
}
void Chat::handleModelLoadedChanged()
{
if (m_shouldDeleteLater)
deleteLater();
}
void Chat::responseStarted()
{
m_responseInProgress = true;
@@ -180,15 +187,26 @@ void Chat::loadModel(const QString &modelName)
emit loadModelRequested(modelName);
}
void Chat::unloadAndDeleteLater()
{
if (!isModelLoaded()) {
deleteLater();
return;
}
m_shouldDeleteLater = true;
unloadModel();
}
void Chat::unloadModel()
{
stopGenerating();
emit unloadModelRequested();
m_llmodel->setShouldBeLoaded(false);
}
void Chat::reloadModel()
{
emit reloadModelRequested(m_savedModelName);
m_llmodel->setShouldBeLoaded(true);
}
void Chat::generatedNameChanged()
@@ -236,12 +254,10 @@ bool Chat::deserialize(QDataStream &stream, int version)
stream >> m_userName;
emit nameChanged();
stream >> m_savedModelName;
// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
// unfortunately, we cannot deserialize these
if (version < 2 && m_savedModelName.contains("gpt4all-j"))
return false;
if (!m_llmodel->deserialize(stream, version))
return false;
if (!m_chatModel->deserialize(stream, version))